From ab9d0c0c1263068dc786f75549f2f8499244a872 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 31 Jul 2024 19:53:38 +0000 Subject: [PATCH] squash-patch changes move heuristic into C++ code fix unit tests + format update for 3.5.1 remove custom scheduler codespell cleanup comment cleanup diff review comments review comments review comment changes review comments fix codespell cleanup util logic make dim names for prepack layout more canoncial missed refactor wip interleaving + recasting tweak tolerances comments plus interleaving format codespell review comments end2end first pass seperate out kernels, format add machete as a gptq backend update to use ModelWeightParameter formatting update parameter.py refactor permute layout wip --- csrc/cutlass_extensions/cute_utils.cuh | 14 +- csrc/cutlass_extensions/torch_utils.hpp | 7 +- .../vllm_numeric_conversion.cuh | 2 +- csrc/ops.h | 4 +- csrc/quantization/machete/generate.py | 2 - .../quantization/machete/machete_mainloop.cuh | 676 +++++++++--------- .../machete/machete_mm_kernel.cuh | 3 +- .../machete/machete_mm_launcher.cuh | 3 +- .../machete/machete_prepack_launcher.cuh | 2 +- .../machete/machete_prepacked_layout.cuh | 20 - csrc/quantization/machete/machete_pytorch.cu | 4 +- csrc/torch_bindings.cpp | 14 +- examples/offline_inference.py | 7 +- vllm/_custom_ops.py | 2 +- .../layers/quantization/awq_marlin.py | 9 +- .../schemes/compressed_tensors_wNa16.py | 95 +-- .../layers/quantization/gptq.py | 3 + .../layers/quantization/gptq_marlin.py | 99 +-- .../quantization/kernels/GPTQLinearKernel.py | 85 +++ .../quantization/kernels/MPLinearKernel.py | 79 ++ .../kernels/MacheteLinearKernel.py | 88 +++ .../kernels/MarlinLinearKernel.py | 128 ++++ .../layers/quantization/kernels/__init__.py | 44 ++ .../layers/quantization/utils/__init__.py | 3 + .../layers/quantization/utils/layer_utils.py | 33 + .../quantization/utils/machete_utils.py | 30 + .../layers/quantization/utils/marlin_utils.py | 29 +- vllm/model_executor/parameter.py | 62 ++ vllm/scalar_type.py | 2 + 29 files changed, 978 insertions(+), 571 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py create mode 100644 vllm/model_executor/layers/quantization/utils/layer_utils.py create mode 100644 vllm/model_executor/layers/quantization/utils/machete_utils.py diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1842fab8b2cac..114a14cd61b88 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -25,9 +25,8 @@ CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && - stride<0>(coalesced_layout) == 1) { + stride<0>(coalesced_layout) == 1) return true; - } return false; } } @@ -52,17 +51,16 @@ static constexpr auto get_logical_ptr(PointerType* ptr) { template CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { constexpr auto bits = sizeof_bits_v * Elements{}; - if constexpr (bits % 128 == 0) { + if constexpr (bits % 128 == 0) return AutoVectorizingCopyWithAssumedAlignment<128>{}; - } else if constexpr (bits % 64 == 0) { + else if constexpr (bits % 64 == 0) return AutoVectorizingCopyWithAssumedAlignment<64>{}; - } else if constexpr (bits % 32 == 0) { + else if constexpr (bits % 32 == 0) return AutoVectorizingCopyWithAssumedAlignment<32>{}; - } else if constexpr (bits % 16 == 0) { + else if constexpr (bits % 16 == 0) return AutoVectorizingCopyWithAssumedAlignment<16>{}; - } else { + else return AutoVectorizingCopyWithAssumedAlignment<8>{}; - } } }; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 1618a340ce10e..ec8b21a62f894 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -17,7 +17,7 @@ namespace detail { template CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, seq) { - return g(f(cute::get(static_cast(t)), I)...); + return g(f(get(static_cast(t)), I)...); } template @@ -29,7 +29,7 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { template CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { - if constexpr (cute::is_tuple::value) { + if constexpr (is_tuple::value) { return detail::tapply_with_idx( t, f, [](auto const&... a) { return cute::make_tuple(a...); }, tuple_seq{}); @@ -72,9 +72,8 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, } } else { // Extra strides are assumed to be 0 or 1 - if constexpr (cute::is_static_v) { + if constexpr (cute::is_static_v) static_assert(StrideEle::value == 0 || StrideEle::value == 1); - } return StrideEle{}; } }); diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 2ad914f8e9868..7561b2505b10e 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -524,7 +524,7 @@ struct NumericArrayConverter { // Below constructs the following temporary: uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; static_assert(RegArray::kElements <= 4, - "Too many inputs for uint4b8_t -> BF16 vector converter"); + "Too many inputs for BF16 -> I4 vector converter"); CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { asm volatile( diff --git a/csrc/ops.h b/csrc/ops.h index 6bf0cff232528..b7f07f5da7d1e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -88,7 +88,7 @@ namespace machete { std::vector supported_schedules( vllm::ScalarTypeTorchPtr const& btype); -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, +torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, vllm::ScalarTypeTorchPtr const& btype, c10::optional const& scales, c10::optional const& zeros, @@ -97,7 +97,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, c10::optional alpha, c10::optional beta, c10::optional schedule); -torch::Tensor prepack_B(torch::Tensor const& B, +torch::Tensor prepack_B(torch::Tensor const B, vllm::ScalarTypeTorchPtr const& btype); }; // namespace machete diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 09a98a5dd1fd6..a98e24c5672d7 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -324,8 +324,6 @@ def create_sources(impl_config: ImplConfig, num_impl_files=2): def generate(): - # See csrc/quantization/machete/Readme.md, the Codegeneration for more info - # about how this works SCRIPT_DIR = os.path.dirname(__file__) schedules = [ diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 3d574ad99efda..4ea7e4e631291 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -1,27 +1,5 @@ -// // Based off of: -// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp -// Specifically: -// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp -// Referred to as upstream from in the comments -// -// The main optimization machete implements compared to upstream is to prepack -// the weight matrix to more closely match the shape of the wgmma instructions -// allowing for wider (ideally 128bit) shared memory loads. For subbyte types -// this is done by packing values from multiple wgmma loads (for a single -// thread) into a single 128bit load. This is very similar to layout used in -// Marlin, although specific to the wgmma instructions. -// -// Since the wgmma instructions only support sourcing from registers for the A -// operand, and we want to upconvert/decompress the weight values/elements -// before feeding them into the tensor cores in registers, we need the weight -// matrix to be A. To achieve this we compute the transpose of Y = XW^t as -// Y^t = W^tX^t. This is mostly done outside of this file in -// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the -// quantized/narrow type and has the prepacked layout despite the API being: -// B_prepacked = machete_prepack_B(B) -// Y = machete_mm(A, B_prepacked) -// +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp #pragma once // clang-format off @@ -49,6 +27,8 @@ #include "cutlass_extensions/cute_utils.cuh" +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace machete { using namespace cute; @@ -57,6 +37,9 @@ using namespace cutlass::gemm; using namespace cutlass::gemm::collective; using namespace cutlass::gemm::collective::detail; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers template (TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); @@ -109,37 +92,18 @@ struct MacheteCollectiveMma { gmma_rs_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = gmma_rs_tag_to_major_B(); - - // For coop schedules we have two warp groups cooperatively issuing wgmma - // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< cute::is_same_v, Layout>, Layout>>; + // Required by kernel using TiledMma = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(), AtomLayoutMNK{})); private: - // - // the setup section (until "section setup end") contains a combination of - // modified code from (used as a starting point): - // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` - // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` - // (upstream) - // - // however in-order to simplify the code we combine a lot of the logic from - // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes - // sense given that we have flexibility on layouts here. We also simplify the - // code by only supporting scales and zeros for A (in the transposed problem, - // B from an API perspective), also since we force A to be the narrow type - // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in - // the upstream also simplifying the code. This section includes new logic - // (compared ustream) for handling the prepacked-A layouts (in the transposed - // problem, B from an API perspective) - // using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; @@ -155,6 +119,7 @@ struct MacheteCollectiveMma { sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, ElementZero, TileShape_MNK>(StageCountType{}); + // Required by kernel struct DispatchPolicy { constexpr static int Stages = PipelineStages; using ClusterShape = ClusterShape_MNK; @@ -347,7 +312,7 @@ struct MacheteCollectiveMma { static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); - private: + public: // TODO: make private static constexpr ConversionMode get_conversion_mode() { if constexpr (cute::is_void_v) { return ConversionMode::DirectConvert; @@ -363,7 +328,6 @@ struct MacheteCollectiveMma { KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; - // Same as upstream, should be kept the same when possible static constexpr auto elements_per_smem_scale() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return 0; @@ -375,7 +339,6 @@ struct MacheteCollectiveMma { } } - // Same as upstream, should be kept the same when possible static constexpr auto elements_per_smem_zero() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale) { @@ -389,43 +352,49 @@ struct MacheteCollectiveMma { } } - // Same as upstream, should be kept the same when possible, not formatte for - // easier comparison - // clang-format off - // These methods use some the public members of the class. For that reason, we define them after the public section. - static constexpr uint32_t - compute_tma_transaction_bytes_mk() { - constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + // These methods use some the public members of the class. For that reason, we + // define them after the public section. + static constexpr uint32_t compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes( + size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * + static_cast(cute::sizeof_bits_v)); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return baseline_bytes; - } - else if constexpr (ModeHasScales) { - constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); - static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + } else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = + (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * + static_cast(cute::sizeof_bits_v) / 8); + static_assert( + scale_tx_bytes % 128 == 0, + "Each scale stage must be 128B aligned."); // required by TMA if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return baseline_bytes + scale_tx_bytes; - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { // Scale and zero share smem layout - constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); - static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + constexpr uint32_t zero_tx_bytes = + (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * + static_cast(cute::sizeof_bits_v) / 8); + static_assert( + zero_tx_bytes % 128 == 0, + "Each zero stage must be 128B aligned."); // required by TMA return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); } } - static constexpr uint32_t - compute_tma_transaction_bytes_nk() { - return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + static constexpr uint32_t compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes( + size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * + static_cast(cute::sizeof_bits_v)); } - // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( @@ -479,26 +448,29 @@ struct MacheteCollectiveMma { } public: - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic - // clang-format off - static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentA = + cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); - static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentB = + cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); - // Just pick the max alignment of A and B since it is required to be at least 128B - static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + // Just pick the max alignment of A and B since it is required to be at least + // 128B + static constexpr size_t SmemAlignmentScale = + cute::max(SmemAlignmentA, SmemAlignmentB); - static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, + "Require at least 128B alignment"); - struct SharedStorage - { + struct SharedStorage { static constexpr int scale_elements = elements_per_smem_scale(); static constexpr int zero_elements = elements_per_smem_zero(); - struct TensorStorage : cute::aligned_struct { + struct TensorStorage + : cute::aligned_struct { cute::ArrayEngine> smem_A; - cute::ArrayEngine> smem_B; + cute::ArrayEngine> + smem_B; cute::ArrayEngine smem_scale; cute::ArrayEngine smem_zero; } tensors; @@ -506,6 +478,7 @@ struct MacheteCollectiveMma { using PipelineStorage = typename MainloopPipeline::SharedStorage; PipelineStorage pipeline; }; + using TensorStorage = typename SharedStorage::TensorStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; @@ -521,16 +494,7 @@ struct MacheteCollectiveMma { ElementZero const* ptr_Z = nullptr; uint32_t mma_promotion_interval = 4; }; - // clang-format on - // - // section setup end - // - - // Similar (but not idendtical) to upstream, should be kept the same when - // possible - // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to - // define the TMA types // Device side kernel params struct Params { public: @@ -540,8 +504,6 @@ struct MacheteCollectiveMma { using TMA_Zero = decltype(make_tma_copy_zero()); using TMA_B = decltype(make_tma_copy_B()); - // required by outer loop: i.e. - // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp TMA_A tma_load_a; TMA_B tma_load_b; TMA_Scale tma_load_scale; @@ -557,10 +519,6 @@ struct MacheteCollectiveMma { // Methods // - // Similar (but not idendtical) to upstream, should be kept the same when - // possible - // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here - // to handle the prepacked layout template static constexpr Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, @@ -615,100 +573,116 @@ struct MacheteCollectiveMma { } } - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // with `SwapAB ? N : M -> M` since we dont support SwapAB - // clang-format off - template - static bool - can_implement( + template + CUTLASS_HOST_DEVICE static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - + auto [M, N, K, L] = problem_shape_MNKL; + bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + constexpr int min_tma_aligned_elements_A = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(N, K, L), StrideB{}); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { implementable = implementable && (args.ptr_S == nullptr); implementable = implementable && (args.ptr_Z == nullptr); - } - else if constexpr (ModeHasScales) { + } else if constexpr (ModeHasScales) { const int scale_mn = M; const int scale_k = (K + args.group_size - 1) / args.group_size; - constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); - implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + constexpr int min_tma_aligned_elements_scale = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = + implementable && (args.group_size == K || + ((args.group_size % size<2>(TileShape{})) == 0)); implementable = implementable && args.group_size != 0; implementable = implementable && (args.ptr_S != nullptr); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { implementable = implementable && (args.ptr_Z == nullptr); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); implementable = implementable && (args.ptr_Z != nullptr); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in can_implement."); } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in can_implement."); } if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for TMA.\n"); } return implementable; } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); - static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + static constexpr uint32_t TmaTransactionBytesMK = + compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = + compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = + TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_b.get_tma_descriptor()); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_scale.get_tma_descriptor()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_zero.get_tma_descriptor()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in TMA prefetch."); } - } - // clang-format off - // Modified from upstream, should be kept close to that when possible - // the main difference is special handling for the prepacked A layout - // - // Set up the data needed by this collective for load and mma. - // Returns a tuple of tensors. The collective and the kernel layer have the - // contract Returned tuple must contain at least two elements, with the first - // two elements being: gA_mkl - The tma tensor, A after a local tile so it - // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local - // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be - // specified as needed by this collective. - // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the - // values within a prepacked block. + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the + /// contract Returned tuple must contain at least two elements, with the first + /// two elements being: gA_mkl - The tma tensor, A after a local tile so it + /// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local + /// tile so it has shape (BLK_N,BLK_K,n,k,l) The rest of the tensors can be + /// specified as needed by this collective. template CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { @@ -765,72 +739,69 @@ struct MacheteCollectiveMma { } } - // Similar to upstream, should be kept close to that when possible - // the main difference is in the layout comments - // clang-format off /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective /// This overload gets triggered when we have scales. - template < - class... Ts, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( - Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { + template + CUTLASS_DEVICE void load(Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof...(Ts) == 4, + "Scaled and zero convert needs four inputs"); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in TMA load."); } int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + SmemLayoutACopy{}); // (TILE_V,TILE_B,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (TILE_N,TILE_K,PIPE) + Tensor sB = + as_position_independent_swizzle_tensor(sB_); // (TILE_N,TILE_K,PIPE) // // Prepare the TMA loads for A, B and Scales // - + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, + block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = + mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = + mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (TILE_N,TILE_K,k) // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; uint16_t mcast_mask_b = 0; @@ -839,24 +810,32 @@ struct MacheteCollectiveMma { // Issue TmaLoads // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = + Layout{}; // (m,n) -> + // block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, + n, Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = + Layout{}; // (m,n) -> + // block_id for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + mcast_mask_b |= (uint16_t(1) << block_layout( + m, cluster_local_block_id.y, Int<0>{})); } } - auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + auto extra_input_partitions = partition_extra_tma_inputs( + mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, + m_coord, l_coord); // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { + for (; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -865,41 +844,51 @@ struct MacheteCollectiveMma { // using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + BarrierType* tma_barrier = + pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do. - } - else if constexpr (ModeHasScales) { + } else if constexpr (ModeHasScales) { auto tSgS = get<0>(extra_input_partitions); auto tSsS = get<1>(extra_input_partitions); - // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes - // on the fly. - // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K - // is a multiple of the threadblock tile K - const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); - const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. - copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Temporary factor which will determine which k tile to reload from + // gmem. Needed so we don't modify tma transaction bytes on the fly. + // We must do a ceiling divide here to correctly handle with + // group_size == K. In that case, we don't require that K is a + // multiple of the threadblock tile K + const int ReloadFactor = + (mainloop_params.group_size + size<2>(TileShape{}) - 1) / + size<2>(TileShape{}); + const int scale_load_k = + *k_tile_iter / + ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), + tSgS(_, _, _, scale_load_k), tSsS(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { // Nothing extra to do - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); - copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), + tZgZ(_, _, _, scale_load_k), tZsZ(_, _, _, write_stage)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); } ++k_tile_iter; @@ -909,35 +898,26 @@ struct MacheteCollectiveMma { } } } - // clang-format off - - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // clang-format off - // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); } } - // clang-format on - // Modified from upstream, should be kept close to that when possible - // the main differences are handling the prepacked A layout, and separating - // the loading of A from upcoverting A - // - // Perform a collective-scoped matrix multiply-accumulate - // Consumer Perspective + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective template CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum, @@ -1171,7 +1151,7 @@ struct MacheteCollectiveMma { warpgroup_fence_operand(accum); } - // Perform a Consumer Epilogue to release all buffers + /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { @@ -1193,134 +1173,118 @@ struct MacheteCollectiveMma { } private: - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // clang-format off /// Utilities for any additional inputs inside of the TMA load template - CUTLASS_DEVICE - auto partition_extra_tma_inputs( - Params const& mainloop_params, - cute::tuple const& load_inputs, - TensorStorage& shared_tensors, - uint2 const& cluster_local_block_id, - int const m_coord, - int const l_coord) { - + CUTLASS_DEVICE auto partition_extra_tma_inputs( + Params const& mainloop_params, cute::tuple const& load_inputs, + TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, + int const m_coord, int const l_coord) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + return cute::tuple{}; + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), + SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gS_mkl = get<2>(load_inputs); - auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); - Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + auto block_tma_s = + mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) - Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tSgS, tSsS); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), + SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gZ_mkl = get<3>(load_inputs); - auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); - Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + auto block_tma_z = + mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) - Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) - return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); } } - // clang-format off - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // clang-format off - /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + /// Utilities for partitioning extra inputs for loading from smem in the + /// mainloop. template - CUTLASS_DEVICE - auto partition_extra_mma_info( - ThreadMma const& mma_thread_slice, - TensorStorage& shared_tensors) { - + CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& thread_mma, + TensorStorage& shared_tensors) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + // noting to do + return cute::tuple{}; + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), + SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = thread_mma.partition_A(sS); + Tensor tCrS = make_tensor( + thread_mma.partition_fragment_A(sS(_, _, Int<0>{})).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), + SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = thread_mma.partition_A(sZ); + Tensor tCrZ = make_tensor( + thread_mma.partition_fragment_A(sZ(_, _, Int<0>{})).shape()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); - } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); } } - // clang-format on - // Same as upstream, should be kept the same when possible, not formatted for - // easier comparison - // clang-format off /// Returns the tiled copy and copy views for the extra inputs. template - CUTLASS_DEVICE - auto retile_extra_mma_info( - TiledMma const& tiled_mma, - cute::tuple& partitioned_extra_info, - int const warp_group_thread_idx) { - + CUTLASS_DEVICE auto retile_extra_mma_info( + TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // nothing to do - return cute::make_tuple(); - } - else if constexpr (ModeHasScales) { - auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); - auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); - Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - + // noting to do + return cute::tuple{}; + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = + make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = + smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D( + cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D( + cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, + tCrZ_copy_view); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); } - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); } } - // clang-format on - // Similar to `copy_A_and_extra_info` upstream, should be kept the same when - // possible - // the main differences this only loads the extra info into registers and - // not A (since we now preload more of A in the main pipeline) - // Load scales and zeros into registers if required + /// Utilities to copy A and extra inputs from smem to RF template CUTLASS_DEVICE void load_extra_info_to_registers( cute::tuple const& partitioned_mma_extra_info, @@ -1355,11 +1319,7 @@ struct MacheteCollectiveMma { } } - // Similar to upstream, should be kept the same when possible. - // the main differences are that `convert_tensor` supports interleaved - // layouts and bfloat16 has been optimized. `transform_internal_A` has also - // been inlined for code simplicity. - // Utilities to transform A. + /// Utilities to transform A. template CUTLASS_DEVICE void transform_A_kblock( TCrA_load const& tCrA_load, cute::Int vec_A, @@ -1418,9 +1378,7 @@ struct MacheteCollectiveMma { } } - // Modified from upstream, should be kept the same when possible - // the main differences is that this version supports interleaved converts - // Utilities for transforming the A operand prior to issuing tensorcore math. + /// Utilities for transforming the A operand prior to issuing tensorcore math. template > @@ -1428,12 +1386,12 @@ struct MacheteCollectiveMma { Tensor const& in, Tensor& out, cute::Int width = {}) { - // This is an element-wise conversion where we expect both tensors to have - // the same layout. As a result, we can cast as a cutlass array to use the - // fast numeric converters without worrying about indexing into the layout. + /// This is an element-wise conversion where we expect both tensors to have + /// the same layout. As a result, we can cast as a cutlass array to use the + /// fast numeric converters without worrying about indexing into the layout. constexpr int N = cosize_v; - // The inputs must be backed by registers & be statically sized. + /// The inputs must be backed by registers & be statically sized. static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); static_assert(is_rmem::value, @@ -1470,4 +1428,8 @@ struct MacheteCollectiveMma { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace machete + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 046e6e5a53652..34589c5bdb3c8 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -152,7 +152,8 @@ struct MacheteKernelTemplate { int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); - int const group_size = maybe_group_size.value_or(K); + int group_size = maybe_group_size.value_or(K); + group_size = (group_size == -1) ? K : group_size; int const scale_k = (K + group_size - 1) / group_size; TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index e2604d4bed3e2..9d404f138a963 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -49,8 +49,7 @@ torch::Tensor run_impl(PyTorchArguments args) { torch::empty({M, N}, torch::TensorOptions() .dtype(equivalent_scalar_type_v) .device(device)); - - auto const &A = args.A, &B = args.B; + *auto const &A = args.A, &B = args.B; auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; auto layout_A = make_cute_layout(A, "A"); diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index 686dd68bd52bb..df78312997fb0 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // clang-format on // Allocate output - torch::Tensor D = torch::empty_like(B); + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); prepack_B(stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 78e2cc5eec7d8..b307341f6f16c 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -30,17 +30,6 @@ using namespace cute; struct IlvBlkLayoutAuto {}; -// This defines a prepacked layout for the B matrix, where the matrix is broken -// up into PPBlockShape_NK blocks. The data within each block is then compactly -// stored in memory such that when performing a TiledMMA operation with the same -// shape as prepacked block, all the data for a given thread is contiguous in -// memory. This allows us to use wider shared memory loads when loading B from -// shared memory. The values within a thread are also potentially interlaeved -// inorder to allow for more efficient upconverting. -// -// The contract here is that the `TiledMma` determined below matches the one -// ultimately used in the kernel. (this is also why the other element types are -// required along with the kernel schedule) template @@ -68,12 +57,6 @@ struct PrepackedLayoutBTemplate { // TODO (LucasWilkinson): compare the performance for other sizes // Prepacked block shape, smallest layout atom for loading into registers // (can contain multiple wgmma instructions worth of data in one block) - // We ideally want this to be configured such that a thread can perform 128bit - // loads, i.e. we amount of data associated with each thread within a - // prepacked block is a multiple of 128bits, when using a cooperative sechdule - // we have 256 threads working a single block at a time, this means each - // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, - // for a 4bit type this would be 128bits using PPBlockShape_NK = Shape<_128, _64>; // Create the shape of the tile anticipated to be used by the GEMM kernel, @@ -87,9 +70,6 @@ struct PrepackedLayoutBTemplate { static constexpr cute::GMMA::Major GmmaMajorB = gmma_rs_tag_to_major_B(); - - // For coop schedules we have two warp groups cooperatively issuing wgmma - // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< cute::is_same_v, diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index ef36a490c3c50..0f68dfdcd0528 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -42,7 +42,7 @@ std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { }); } -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, +torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, ScalarTypeTorchPtr const& btype, c10::optional const& scales, c10::optional const& zeros, @@ -69,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, }); } -torch::Tensor prepack_B(torch::Tensor const& B, +torch::Tensor prepack_B(torch::Tensor const B, ScalarTypeTorchPtr const& btype) { return scalar_type_dispatch(*btype, [&](auto BType) { return PrepackBDispatcher::dispatch(B); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6d1f53b75f4e2..f4c8d406c671b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -135,17 +135,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. ops.def("machete_supported_schedules", &machete::supported_schedules); - ops.def( - "machete_gemm(Tensor A, Tensor B," - " __torch__.torch.classes._core_C.ScalarType btype," - " Tensor? scales, Tensor? zeros, int? group_size," - " Tensor? C, float? alpha, float? beta, str? schedule)" - "-> Tensor"); + ops.impl("machete_supported_schedules", torch::kCPU, + &machete::supported_schedules); + ops.def("machete_gemm", &machete::gemm); ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); - ops.def( - "machete_prepack_B(Tensor B," - " __torch__.torch.classes._core_C.ScalarType btype)" - "-> Tensor"); + ops.def("machete_prepack_B", &machete::prepack_B); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); // gptq_marlin Optimized Quantized GEMM for GPTQ. diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..df303229118b9 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,12 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +# GPTQ = "kaitchup/Llama-2-7b-gptq-3bit" +# marlin = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ" +# machete = "TheBloke/Llama-2-7B-GPTQ" +# machete/marlin CT = "nm-testing/tinyllama-oneshot-w4a16-group128-v2" +# "nm-testing/tinyllama-oneshot-w4a16-channel-v2" +llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-channel-v2") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b89a90ef0f70c..1c285cfbfb20f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -336,7 +336,7 @@ def machete_supported_schedules(b_type: ScalarType) -> List[str]: def machete_gemm( a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, b_type: ScalarType, b_scales: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..ba1e519d64370 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -7,10 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -231,7 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -239,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. marlin_zp = awq_to_marlin_zero_points( @@ -247,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..17fd688a09f10 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -2,13 +2,10 @@ import torch -from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -46,23 +43,32 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + act_reordering=False + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + # If group_size is -1, we are in channelwise case. channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size @@ -71,12 +77,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, # scales across all gpus. partition_scales = (row_parallel and not channelwise) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -123,62 +123,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name=None) # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f456286899a53..90ed7de49f7e0 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -218,6 +218,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) + print(layer.qzeros) + print(hex(layer.qzeros[0][0].to(torch.uint32).item())) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..fed393dd97771 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,16 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + check_marlin_supported, marlin_repeat_scales_on_all_ranks, + verify_marlin_supported) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -163,24 +161,29 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + act_reordering=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -261,55 +264,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. - marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -317,16 +280,4 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py new file mode 100644 index 0000000000000..d8b2de6141d63 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py @@ -0,0 +1,85 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.parameter import (ModelWeightParameter, + PackedvLLMParameter) + +from .MPLinearKernel import * + + +class GPTQLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if c.act_type != torch.half: + return False, f"Act type {c.act_type} currently not supported by GPTQLinearKernel" + + if c.zero_points: + return False, "Zero points currently not supported by GPTQLinearKernel" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + def transform_w_q(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, PackedvLLMParameter): + x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0) + return ops.machete_prepack_B(x.t().contiguous().t(), + self.config.weight_type) + + def transform_w_s(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, ModelWeightParameter): + x = x.permute_layout(input_dim=0, output_dim=1) + return x.contiguous() + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000000000..185e40c251310 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + act_reordering: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + # note assumes that (if the they are not ModelWeightParameters) + # `getattr(layer, w_q_name)` is: + # {input_dim = 0, output_dim = 1, packed_dim = 0} + # `getattr(layer, w_s_name)` is: + # {input_dim = 0, output_dim = 1} + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + replace_parameter(layer, name, fn(getattr(layer, name))) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], Optional[torch.Tensor]]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py new file mode 100644 index 0000000000000..275c767a8c1be --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py @@ -0,0 +1,88 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import * + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.act_reordering: + return False, "Act reordering currently not supported by Machete" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + print(w_s) + print(c.group_size) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py new file mode 100644 index 0000000000000..fb59f6d4352aa --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py @@ -0,0 +1,128 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import * + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.act_reordering, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.act_reordering: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading`` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000000000..22172771e5b64 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,44 @@ +from typing import List, Optional, Type + +from vllm.platforms import current_platform + +from .MacheteLinearKernel import MacheteLinearKernel +from .MarlinLinearKernel import MarlinLinearKernel +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index e69de29bb2d1d..e60f0c79ac1f7 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000000000..c38bd8955f457 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,33 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new) + mod.register_parameter(name, torch.nn.Parameter(new)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000000000..18e1332050cdd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..e83b4eacf8f38 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -118,6 +118,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -146,6 +159,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -221,17 +239,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index c6cfab7892efa..6d84a5ecb3492 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -320,6 +320,68 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_( + param: BasevLLMParameter, + input_dim: int, + output_dim: int, + **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters where either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index eb491dd1554a8..373151a5311e5 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -27,6 +27,8 @@ class scalar_types: float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) uint4b8 = ScalarType.uint(4, 8) uint8b128 = ScalarType.uint(8, 128)