From 6ed63f2ebae2d2d6742cc08c855ac8e5b6eb7cd1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Dec 2024 16:41:45 +0000 Subject: [PATCH] Grouped gemm working Co-authored-by: Lucas Wilkinson Signed-off-by: ElizaWszola --- .../broadcast_load_epilogue_array_c3x.hpp | 464 ++++++++++++++++++ .../epilogue/broadcast_load_epilogue_c3x.hpp | 5 + .../epilogue/scaled_mm_epilogues_c3x.hpp | 64 +++ csrc/ops.h | 12 +- .../cutlass_w8a8/grouped_gemm_test.cu | 224 ++++----- .../cutlass_w8a8/scaled_mm_entry.cu | 26 +- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_cutlass.py | 155 ++---- 8 files changed, 704 insertions(+), 254 deletions(-) create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 0000000000000..e652179718c95 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,464 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. 
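+//
+// In the *Array variants defined here, the tensor is supplied as an array of
+// per-group device pointers (ptr_row_array / ptr_col_array) indexed by the
+// group coordinate, so a single grouped kernel launch can apply a different
+// scale vector or scalar to every group.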
+// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , 
residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + if (threadIdx.x ==128){ + printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + } + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
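+  //
+  // A minimal usage sketch (the names below are illustrative, not part of
+  // this file): given `scale_ptrs`, a device array holding one
+  // `const Element*` per group,
+  //
+  //   Arguments per_row_scales{scale_ptrs, /*col_broadcast=*/true, {}};
+  //   Arguments per_group_scalars{scale_ptrs, /*col_broadcast=*/false, {}};
+  //
+  // selects between loading a length-M column of per-row values for each
+  // group and broadcasting a single scalar per group.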
+ struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args
+  >
+  CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs const& args) {
+
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+
+    // if (threadIdx.x ==128){
+    //   printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
+    // }
+    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
+    Tensor tCgCol = sm90_partition_for_epilogue(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tCrCol = make_tensor_like(tCgCol);                       // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    // Generate an identity tensor matching the shape of the global tensor and
+    //  partition the same way, this will be used to generate the predicate
+    // tensor for loading
+    Tensor cCol = make_identity_tensor(mCol.shape());
+    Tensor tCcCol = sm90_partition_for_epilogue(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+
+    return ConsumerStoreCallbacks(
+      cute::move(tCgCol),
+      cute::move(tCrCol),
+      cute::move(tCcCol),
+      args.problem_shape_mnkl,
+      l,
+      params
+    );
+  }
+};
+
+}
diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
index 58b1e8ff159fb..9f049efd07b46 100644
--- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
@@ -422,6 +422,11 @@ struct Sm90ColOrScalarBroadcast {
   get_consumer_store_callbacks(ConsumerStoreArgs const& args) {
 
     auto [M, N, K, L] = args.problem_shape_mnkl;
+    // auto [m, n, k, l] = args.tile_coord_mnkl;
+
+    // if (threadIdx.x ==128){
+    //   printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
+    // }
     Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
     Tensor tCgCol = sm90_partition_for_epilogue(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index 95764ecddc79f..ad7c45a076e68 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,4 +1,5 @@
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
 
 /*
    This file defines custom epilogues for fusing channel scales, token scales,
@@ -45,6 +46,16 @@ struct ScaledEpilogueBase {
       0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
       Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>;
 
+  template <typename T>
+  using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<0>, Int<1>, Int<0>>>;
+
   // This utility function constructs the arguments for the load descriptors
   // from a tensor. It can handle both row and column, as well as row/column or
   // scalar cases.
@@ -72,6 +83,15 @@ struct ScaledEpilogueBase {
                   std::is_same_v>);
     return Arguments{data_ptr};
   }
+
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
+    using Arguments = typename Descriptor::Arguments;
+    static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
+    return Arguments{data_ptr, do_broadcast};
+  }
+
 };
 
 /*
@@ -312,4 +332,48 @@ struct ScaledEpilogueBiasAzpToken
   }
 };
 
+/*
+   Like ScaledEpilogue, but the a/b scales are given as arrays of per-group
+   device pointers (plus vector-vs-scalar flags), for the grouped GEMM path.
+*/
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueArray
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
+
+  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
+  using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
+
+  static ArgumentType prepare_args(const float* const* a_scales_ptr,
+                                   const float* const* b_scales_ptr,
+                                   bool a_col_broadcast,
+                                   bool b_row_broadcast) {
+    auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(a_scales_ptr, a_col_broadcast);
+    auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(b_scales_ptr, b_row_broadcast);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args};
+  }
+};
+
 };  // namespace vllm::c3x
\ No newline at end of file
diff --git a/csrc/ops.h b/csrc/ops.h
index fce4346fa4218..b655d3bfab58a 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -145,13 +145,11 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional const& bias);
 
-void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a,
-                        torch::Tensor const& b, torch::Tensor const& a_scales,
-                        torch::Tensor const& b_scales,
-                        torch::Tensor const& problem_sizes,
-                        torch::Tensor const& out_offsets,
-                        torch::Tensor const& a_offsets,
-                        torch::Tensor const& b_offsets);
+void cutlass_grouped_mm(c10::List const& out_tensors,
+                        c10::List const& a_tensors,
+                        c10::List const& b_tensors,
+                        c10::List const& a_scales,
+                        c10::List const& b_scales);
 
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
index 03d23c7739691..c9d299c111304 100644
--- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
+++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
@@ -38,11 +38,6 @@ using namespace cute;
 
 namespace {
 
-// A wrapper for the GEMM kernel that is used to guard against compilation on
-// architectures that will never use the kernel. The purpose of this is to
-// reduce the size of the compiled binary.
-// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
-// into code that will be executed on the device where it is defined.
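+// Wrapper that compiles the kernel body only for SM90+ (guarded by
+// __CUDA_ARCH__ in device code) so builds targeting older architectures stay
+// small.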
template struct enable_sm90_or_later : Kernel { template @@ -54,19 +49,13 @@ struct enable_sm90_or_later : Kernel { }; using ProblemShape = - cutlass::gemm::GroupProblemShape>; // - // per group -using ElementAB_Type = - cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // -// Element type for B matrix operand + cutlass::gemm::GroupProblemShape>; +using ElementAB_Type = cutlass::float_e4m3_t; using ElementC_Type = cutlass::half_t; -// Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -154,129 +143,109 @@ void print(const std::tuple& _tup) { } //////////// -template -void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { +template +void cutlass_group_gemm_caller(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; - using ElementAcc = float; - int groups = problem_sizes.size(0); + int groups = (int)a_tensors.size(); + TORCH_CHECK((int)b_tensors.size() == groups, + "Number of B tensors must match number of groups."); + TORCH_CHECK((int)out_tensors.size() == groups, + "Number of output tensors must match number of groups."); + std::vector a_ptrs_host(groups); std::vector b_ptrs_host(groups); std::vector c_ptrs_host(groups); std::vector d_ptrs_host(groups); + std::vector a_scales_ptrs_host(groups); + std::vector b_scales_ptrs_host(groups); + + std::vector problem_sizes_host; + problem_sizes_host.reserve(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = static_cast(a.data_ptr()) + - a_offsets[g].item(); - b_ptrs_host.at(g) = static_cast(b.data_ptr()) + - b_offsets[g].item(); - c_ptrs_host.at(g) = static_cast(out.data_ptr()) + - out_offsets[g].item(); - d_ptrs_host.at(g) = - static_cast(out.data_ptr()) + out_offsets[g].item(); - printf("off: %d %d %d\n", a_offsets[g].item(), - b_offsets[g].item(), out_offsets[g].item()); + a_ptrs_host[g] = + reinterpret_cast(a_tensors[g].data_ptr()); + b_ptrs_host[g] = + reinterpret_cast(b_tensors[g].data_ptr()); + c_ptrs_host[g] = + reinterpret_cast(out_tensors[g].data_ptr()); + d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); + a_scales_ptrs_host[g] = + reinterpret_cast(a_scales[g].data_ptr()); + b_scales_ptrs_host[g] = + reinterpret_cast(b_scales[g].data_ptr()); + + int64_t m = a_tensors[g].size(0); + int64_t k = a_tensors[g].size(1); + + int64_t k_b = b_tensors[g].size(0); + int64_t n = b_tensors[g].size(1); + + TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, + " while B has k=", k_b); + + // Optionally, verify output shape matches (m,n) + TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, + "Output tensor shape does not match m,n from A,B: ", "Got ", + out_tensors[g].sizes(), 
" expected (", m, ", ", n, ")"); + + problem_sizes_host.push_back({(int)m, (int)n, (int)k}); } using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; - // using StrideA = typename GemmKernel::InternalStrideA; - // using StrideB = typename GemmKernel::InternalStrideB; - // using StrideC = typename GemmKernel::InternalStrideC; - // // using StrideD = typename GemmKernel::InternalStrideD; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); + for (int32_t g = 0; g < groups; ++g) { + int64_t lda = a_tensors[g].stride(0); // row-major (m x k) + int64_t ldb = b_tensors[g].stride(1); // column-major (k x n) + int64_t ldc = out_tensors[g].stride(0); // row-major (m x n) - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; - using StrideC = - typename GemmKernel::InternalStrideC; // typename Gemm::StrideC; - - // StrideA a_stride{lda, Int<1>{}, Int<0>{}}; - // StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; - // StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - std::vector a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}}); - std::vector c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}}); - - printf("a: "); - print(a_stride_host[0]); - printf("\nb: "); - print(b_stride_host[0]); - printf("\nc: "); - print(c_stride_host[0]); - printf("\n"); - - // for (int g = 0; g < groups; ++g) { - // int32_t m = problem_sizes[g][0].item(); - // int32_t n = problem_sizes[g][1].item(); - // int32_t k = problem_sizes[g][2].item(); - // a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, - // // row - // b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, - // // col - // c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, - // // row - // } + a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; + b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; + c_stride_host[g] = StrideC{ldc, Int<1>{}, Int<0>{}}; + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. 
hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; - - std::vector problem_sizes_host; - problem_sizes_host.reserve(groups); - for (int32_t g = 0; g < groups; ++g) { - int32_t m = problem_sizes[g][0].item(); - int32_t n = problem_sizes[g][1].item(); - int32_t k = problem_sizes[g][2].item(); - problem_sizes_host.push_back({m, n, k}); - printf("mnk: %d, %d, %d\n", m, n, k); - } - - auto problem_sizes_ptr = - make_device_ptr(problem_sizes_host); + auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; - // ElementAB* a_host_print; - // int numel = a.numel(); - // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); - // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* - // sizeof(ElementAB), cudaMemcpyDeviceToHost); - // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* - // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); + auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); - auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); - auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); + auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); + auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); - auto a_stride_ptr = make_device_ptr(a_stride_host); - auto b_stride_ptr = make_device_ptr(b_stride_host); - auto c_stride_ptr = make_device_ptr(c_stride_host); + auto a_stride_ptr = make_device_ptr(a_stride_host); + auto b_stride_ptr = make_device_ptr(b_stride_host); + auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), + a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), + a_scales[0].numel() != 1, b_scales[0].numel() != 1), c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; @@ -284,30 +253,26 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, epilogue_args, hw_info}; - // Launch the CUTLASS GEMM kernel. 
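+  // Launch the grouped GEMM through CUTLASS's device-level adapter.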
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; + // std::cout << "gemm_op.can_implement(args): " + // << (int)gemm_op.can_implement(args) << std::endl; CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors[0].device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - + auto stream = at::cuda::getCurrentCUDAStream(a_tensors[0].device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } -// typedef InType = cutlass::float_e4m3_t; -// typedef OutType = torch::half; - template typename Epilogue> struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); + static_assert(std::is_same_v); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = @@ -354,18 +319,23 @@ struct sm90_fp8_config_M64 { } // namespace -// TODO hardcode types here? -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - // int32_t m = a.size(1); +void cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + TORCH_CHECK(a_tensors.size() > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size() > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size() > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); using Cutlass3xGemmDefault = typename sm90_fp8_config_default< - ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; + ElementAB_Type, ElementC_Type, + vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; // using Cutlass3xGemmM64 = // typename sm90_fp8_config_M64::Cutlass3xGemm; @@ -388,7 +358,5 @@ void cutlass_grouped_mm_sm90( // } else { // // m in (128, inf) cutlass_group_gemm_caller( - out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, - b_scales); - // } + out_tensors, a_tensors, b_tensors, a_scales, b_scales); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 961437893dee0..eb5d09a6de7ba 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,11 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); +void 
cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales); #endif @@ -158,15 +158,13 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - cutlass_grouped_mm_sm90(out, a, b, a_scales, b_scales, problem_sizes, - out_offsets, a_offsets, b_offsets); +void cutlass_grouped_mm(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, + b_scales); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a10c661b22a6a..22a1a1a4ae080 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -313,10 +313,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( - "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " - " Tensor b_scales, Tensor problem_sizes, " - " Tensor out_offsets, Tensor a_offsets, " - " Tensor b_offsets) -> ()"); + "cutlass_grouped_mm(Tensor![] out_tensors," + " Tensor[] a_tensors," + " Tensor[] b_tensors, Tensor[] a_scales, " + " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); // Mamba selective scan kernel diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 1532feba47d6a..4c909669aa5d3 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -457,116 +457,69 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) -# TODO fix scales -@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) -@pytest.mark.parametrize("num_groups", [1, 4, 10]) +@pytest.mark.parametrize("num_groups", [8]) @pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) @pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, per_out_ch: bool, - use_bias: bool): +def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): - # Test for a cutlass kernel with per-token activation quantization - # and per-output channel weight quantization. 
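+    # Each group gets its own A (m_g x k_g, row-major), B (k_g x n_g,
+    # column-major) and output (m_g x n_g) tensors; A and B are FP8 and the
+    # scales are float32, either per-token/per-channel vectors or scalars.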
+    # Device and dtype setup
     device = "cuda"
     out_dtype = torch.half
 
-    alignment = 16  # 128 // 8
-    problem_sizes = torch.empty((num_groups, 3), device="cpu")
-    offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32)
-    offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32)
-    offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32)
-    tot_a = 0
-    tot_b = 0
-    tot_c = 0
-    m = alignment * random.randint(1, 64)
-    n = alignment * random.randint(1, 64)
-    k = alignment * random.randint(1, 64)
-    for g in range(num_groups):
-        tot_a += m
-        tot_b += n
-        tot_c += m
-        print(m, n, k)
-        offsets_a[g] = g * m * k
-        offsets_b[g] = g * k * n
-        offsets_c[g] = g * m * n
-        problem_sizes[g][0] = m
-        problem_sizes[g][1] = n
-        problem_sizes[g][2] = k
-
-    a = to_fp8(torch.randn((tot_a, k), device=device))
-
-    b_float = torch.randn((tot_b, k), device=device)
-    # for g in range(num_groups):
-    #     b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1)
-    # print(b_float)
-
-    b = to_fp8(b_float.t())
-    c = torch.zeros((tot_c, n), device=device).to(out_dtype)
-    baseline = torch.zeros((tot_c, n), device=device).to(out_dtype)
-
-    # print(a)
-    # print(b)
-
-    # print(offsets_a)
-    # print(offsets_b)
-    # print(offsets_c)
-    # print(tot_a, tot_b, tot_c)
-
-    # print(a.stride(), b.stride(), c.stride())
-
-    scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)),
-                           device=device,
-                           dtype=torch.float32))
-    scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)),
-                           device=device,
-                           dtype=torch.float32))
-
-    # if use_bias:
-    #     bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10
-    # else:
-    #     bias = None
-
-    # print(a)
-
-    # TODO strides we can get later the same way as in scaled_mm_c3x.cu
-    torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes,
-                                    offsets_c, offsets_a, offsets_b)
-    # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None)
-
-    # print(a.dtype)
-    # print(a)
-
-    # torch.set_printoptions(profile='full')
-    # # print(c[2*m:3*m])
-    # print(torch.max(c, dim=1))
-    # print(torch.max(c, dim=0))
-    # print(c)
+    # Create separate A, B, C tensors for each group
+    a_tensors = []
+    b_tensors = []
+    a_scales_tensors = []
+    b_scales_tensors = []
+    out_tensors = []
+    baseline_tensors = []
 
+    alignment = 16  # 128 // 8
+    # For variation, each group gets its own randomly chosen problem size
+    # (m_g, n_g, k_g), with every dimension a multiple of the alignment.
     for g in range(num_groups):
-        print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape)
-        baseline[g * m:(g + 1) * m] = baseline_scaled_mm(
-            a[g * m:(g + 1) * m],
-            b[:, g * n:(g + 1) * n],
-            # scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g],
-            # # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g],
-            # scale_b[g],
-            scale_a,
-            scale_b,
-            out_dtype,
-            None)
-        print(baseline[g * m:(g + 1) * m])
-        print(c[g * m:(g + 1) * m])
+        m_g = alignment * random.randint(1, 64)
+        n_g = alignment * random.randint(1, 64)
+        k_g = alignment * random.randint(1, 64)
+
+        m_a_scales = m_g if per_act_token else 1
+        n_b_scales = n_g if per_out_ch else 1
+
+        # Create group-specific A and B (FP8) and output (FP16/FP32)
+        a_g = to_fp8(torch.randn((m_g, k_g), device=device))
+        b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
+        c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype)
+
+        # Set up A/B scales
+        scale_a = torch.randn((m_a_scales, 1),
+                              device=device,
+                              dtype=torch.float32)
+        scale_b = torch.randn((1, n_b_scales),
+                              device=device,
+                              dtype=torch.float32)
+
+        a_tensors.append(a_g)
+        b_tensors.append(b_g)
+        out_tensors.append(c_g)
+        a_scales_tensors.append(scale_a)
+        b_scales_tensors.append(scale_b)
+
+        # Compute baseline result for this group
+        baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
+                                        None)
+        baseline_tensors.append(baseline_g)
+
+    torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors, b_tensors,
+                                    a_scales_tensors, b_scales_tensors)
+
+    # Validate each group's result against the baseline
+    for c_g, baseline_g in zip(out_tensors, baseline_tensors):
-        print("*")
-
-    # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None)
-    # print(baseline)
-    # print(c)
-
-    torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2)
-
-    # opcheck(torch.ops._C.cutlass_scaled_mm,
-    #         (out, a, b, scale_a, scale_b, bias))
+        torch.testing.assert_close(c_g, baseline_g, rtol=1e-2, atol=5e-2)