diff --git a/bitsandbytes/backends/npu.py b/bitsandbytes/backends/npu.py
index d22fe04e8..a268da909 100644
--- a/bitsandbytes/backends/npu.py
+++ b/bitsandbytes/backends/npu.py
@@ -11,6 +11,7 @@
 from bitsandbytes.cextension import lib
 from bitsandbytes.functional import (
+    COOSparseTensor,
     get_4bit_type,
     get_ptr,
 )
@@ -28,6 +29,43 @@ def assert_on_npu(tensors):
     return True
 
 
+def coo_zeros(rows, cols, rowidx, colidx, values, nnz, device, dtype=torch.half):
+    rowidx = rowidx.to(torch.int32)
+    colidx = colidx.to(torch.int32)
+    values = values.to(device).to(dtype)
+    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
+
+
+def row_col_stats(A, threshold):
+    cols = A.shape[-1]
+    if len(A.shape) == 3:
+        rows = A.shape[0] * A.shape[1]
+    else:
+        rows = A.shape[0]
+
+    row_max = torch.zeros(rows, dtype=torch.float32, device="npu")
+    col_max = torch.zeros(cols, dtype=torch.float32, device="npu")
+    outlier_num = torch.zeros(1, dtype=torch.int32, device="npu")
+    lib.cget_col_row_stats(
+        get_ptr(A),
+        get_ptr(row_max),
+        get_ptr(col_max),
+        get_ptr(outlier_num),
+        ct.c_float(threshold),
+        ct.c_int32(rows),
+        ct.c_int32(cols),
+        torch.npu.current_stream()
+    )
+    return row_max, col_max, outlier_num
+
+
+class Int8AB:
+    def __init__(self, A: torch.Tensor, B: torch.Tensor):
+        self.A = A
+        self.B = B
+        self.device = A.device
+
+
 class NPUBackend(Backend):
     def int8_double_quant(
         self,
@@ -38,7 +76,53 @@ def int8_double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
+        past_device = None
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "npu"
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        if past_device != str(A.device):
+            torch.npu.set_device(A.device)  # reset context
+            past_device = str(A.device)
+
+        row_stats, col_stats, cnt_npu = row_col_stats(A, threshold)
+
+        quant_row = torch.empty((rows, cols), dtype=torch.int8, device=device)
+        quant_col = torch.empty((rows, cols), dtype=torch.int8, device=device)
+        outliers_row_idx = torch.zeros(rows, dtype=torch.int32, device=device)
+        outliers_col_idx = torch.zeros(40 * cols, dtype=torch.int32, device=device) - 1
+        outliers_value = torch.empty(0, dtype=torch.float16, device=device)
+
+        lib.cdouble_rowcol_quant(
+            get_ptr(A),
+            get_ptr(row_stats),
+            get_ptr(col_stats),
+            get_ptr(quant_row),
+            get_ptr(quant_col),
+            get_ptr(outliers_row_idx),
+            get_ptr(outliers_col_idx),
+            get_ptr(outliers_value),
+            ct.c_int(cols),
+            ct.c_float(threshold),
+            ct.c_int32(rows),
+            ct.c_int32(cols),
+            torch.npu.current_stream()
+        )
+
+        colidx_tmp = torch.unique(outliers_col_idx)
+        colidx = colidx_tmp[colidx_tmp != -1]
+
+        coo_tensor = None
+        if threshold != 0.0:
+            coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)
+
+        return quant_row, quant_col, row_stats, col_stats, coo_tensor
 
     def int8_vectorwise_dequant(self, A, stats):
         return super().int8_vectorwise_dequant(A, stats)
@@ -48,7 +132,35 @@ def int8_vectorwise_quant(
         A: torch.Tensor,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "npu"
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        A_no_threshold = None
+        if threshold > 0.0:
+            zero = torch.tensor(0.0, dtype=torch.half, device=device)
+            A_no_threshold = torch.where(A.view(rows, cols).abs() < threshold, A.view(rows, cols), zero)
+            row_stats = torch.amax(A_no_threshold.abs(), dim=1, keepdim=True).to(device)
+            out_row = torch.round(A_no_threshold * 127.0 / row_stats).to(torch.int8)
+        else:
+            row_stats = torch.amax(A.view(rows, cols).abs(), dim=1, keepdim=True).to(device)
+            out_row = torch.round(A * 127.0 / row_stats).to(torch.int8)
+
+        outlier_cols = None
+        if threshold > 0.0:
+            # TODO we could improve perf of this
+            outliers = A.abs() >= threshold
+
+            if outliers.any():
+                outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+
+        return out_row, row_stats, outlier_cols
 
     def transform(
         self,
@@ -69,7 +181,7 @@ def int8_linear_matmul(
         out: Optional[torch.Tensor] = None,
         dtype=torch.int32,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        return Int8AB(A, B)
 
     def int8_mm_dequant(
         self,
@@ -79,7 +191,15 @@ def int8_mm_dequant(
         out: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        A, B = A.A, A.B
+        out = torch_npu.npu_quant_matmul(
+            A,
+            B.t(),
+            scale=col_stats.float() / 127.0,
+            pertoken_scale=row_stats.float().view(-1) / 127.0,
+            output_dtype=torch.float16
+        )
+        return out
 
     def extract_outliers(
         self,
@@ -106,6 +226,10 @@ def quantize_4bit(
         if blocksize is None:
            blocksize = 128
 
+        total_blocks = A.numel() // blocksize
+        chunks = 8 if A.numel() > 2048 * 2048 else 1
+        chunksize = (total_blocks + chunks - 1) // chunks
+
        prev_device = torch.npu.current_device()
        torch.npu.set_device(A.device)
        if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
@@ -128,12 +252,27 @@
                 1.0,
             ]
             data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
-            absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
-            a = A.view(-1, blocksize) / absmax.float()
-            diff = torch.abs(a.unsqueeze(-1) - data)
-            out = (torch.argmin(diff, dim=-1) + 8) % 16
-            out = out.reshape(-1, 2)
-            out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
+            chunks_absmax = []
+            chunks_out = []
+
+            for i in range(chunks):
+                start = i * chunksize * blocksize
+                end = min((i + 1) * chunksize * blocksize, A.numel())
+                chunk_data = A.view(-1)[start:end].view(-1, blocksize)
+
+                absmax = chunk_data.abs().max(dim=1, keepdim=True).values
+                chunks_absmax.append(absmax)
+
+                a = chunk_data / absmax.float()
+                diff = torch.abs(a.unsqueeze(-1) - data)
+                out = (torch.argmin(diff, dim=-1) + 8) % 16
+
+                out = out.reshape(-1, 2)
+                out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
+                chunks_out.append(out)
+
+            absmax = torch.cat(chunks_absmax, dim=0)
+            out = torch.cat(chunks_out, dim=0)
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         assert_on_npu([A, absmax, out])
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index a48a58414..09151951f 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import importlib
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings
@@ -320,9 +321,6 @@ def cpu(self, non_blocking: bool = False):
         return self.to(device="cpu", non_blocking=non_blocking)
 
     def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
-        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
-        if isinstance(device, int):
-            device = f"npu:{device}"
         return self.to(device="npu" if device is None else device, non_blocking=non_blocking)
 
     def xpu(self, non_blocking: bool = False):
@@ -345,7 +343,10 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
+        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if device is not None and importlib.util.find_spec("torch_npu") and device.type == "cuda" and not self.bnb_quantized:
+            return self._quantize(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
+        elif device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
@@ -677,6 +678,19 @@ def xpu(self, device):
         self.SCB = SCB
         return self
 
+    def npu(self, device):
+        # we store the 8-bit row-major weight
+        B = self.data.contiguous().to(torch.float16).npu(device)
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        if CBt is not None:
+            del CBt
+        if SCBt is not None:
+            del SCBt
+        self.data = CB
+        self.CB = CB
+        self.SCB = SCB
+        return self
+
     @overload
     def to(
         self: T,
@@ -695,7 +709,10 @@ def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
         if device is not None:
-            if device.type == "cuda" and self.data.device.type == "cpu":
+            # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+ if importlib.util.find_spec("torch_npu") and device.type == "cuda": + return self.npu(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu")) + elif device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) elif device.type == "cpu": if self.data.dtype == torch.int8: diff --git a/csrc/npu_kernels.cpp b/csrc/npu_kernels.cpp index c70e71681..6df262e11 100644 --- a/csrc/npu_kernels.cpp +++ b/csrc/npu_kernels.cpp @@ -5,12 +5,6 @@ using namespace AscendC; constexpr int32_t BUFFER_NUM = 1; -constexpr half Q_COFF_0 = -0.377685546875; -constexpr half Q_COFF_1 = -3.193359375; -constexpr half Q_COFF_2 = 0.583984375; -constexpr half Q_COFF_3 = 6.02734375; -constexpr half Q_COFF_4 = 1.9560546875; -constexpr half Q_COFF_5 = 7.08984375; #define CEIL32(num) (((num) + 32 - 1) / 32 * 32) #define CEIL_BASE(num, base) (((num) + (base) - 1) / (base) * (base)) @@ -200,6 +194,1080 @@ class KernelDequantizeBlockwiseNf4 { }; +namespace row_col_quant_kernel { + constexpr uint32_t DEFAULT_MIN_BLOCK_SIZE = 32; + + struct CurrentTileOffset { + uint32_t rowIndex = 0; // Uint: element + uint32_t colIndex = 0; // Uint: element + }; + + struct CopyParam { + uint16_t blockCount; + uint16_t blockLen; + uint16_t blockLen_int8; + uint16_t stride; + uint16_t stride_int8; + }; + + // tiling for RowColQuant Vector on one VectorCore + struct RowColQuantTilingKernel { + uint32_t coreIdx = 0; // vector core idx + uint32_t is32BAligned = 1; + uint32_t usedCoreNum = 1; + uint64_t totalBlockLen = 0; + uint32_t inputDataType = 1; + uint64_t colLen = 0; + uint64_t rowLen = 0; + + uint32_t baseRowLen = 0; // for one tile in one core, Unit:element + uint32_t baseColLen = 0; // for one tile in one core, Unit:element + uint32_t tailRowLen = 0; // number of tail row in one core, Unit:element + uint32_t tailColLen = 0; // number of column in one core, Unit:element + + uint32_t rowAlignedLen = 0; + + uint32_t tileLength = 0; // baseRowLen * baseColLen + + uint64_t rowTileNum = 0; + uint64_t colTileNum = 0; + uint64_t totalTileNum = 0; + + uint64_t baseRowTileNum = 0; + uint64_t baseColTileNum = 0; + + uint64_t baseRowBaseColCalLen = 0; + uint64_t baseRowTailColCalLen = 0; + uint64_t tailRowBaseColCalLen = 0; + uint64_t tailRowTailColCalLen = 0; + CopyParam baseRowBaseColCopyParam; + CopyParam baseRowTailColCopyParam; + CopyParam tailRowBaseColCopyParam; + CopyParam tailRowTailColCopyParam; + + float threshold = 0.0f; + uint32_t outliersNum = 0; + uint32_t isColQuant = 0; + uint32_t isOutlierIndex = 0; + + uint32_t curCalLen = 0; // curCalRowLen * curCalColLen + uint64_t curCalAlignedRowLen = 0; + uint32_t curCalRowLen = 0; // row length of current tile + uint32_t curCalColLen = 0; // aligned col length of current tile. ALIGNUP(curColLen, alignedLen) Uint: element + uint32_t curColLen = 0; // col length of current tile. 
Uint: element + uint64_t gmOffset = 0; + float curCalRowLen_float = 0.0; + CopyParam *curTileCopyParam = nullptr; + CurrentTileOffset curTileOffset; + + uint32_t usedCoreNumForOutlier = 0; + uint32_t baseCoreNumForOutlier = 0; + OutlierTilingParam baseCoreParam; + OutlierTilingParam tailCoreParam; + OutlierTilingParam curCoreParam; + uint32_t copyInOffset; + + // calc tiling data + __aicore__ void GetTilingAndOffset(GM_ADDR tilingGm_, uint32_t inputDTypeLen_) + { + auto tempTilingGm = (__gm__ RowColQuantTilingData *)tilingGm_; + inputDataType = 2; + + // input scalar parameters + outliersNum = tempTilingGm->outliersNum; + usedCoreNum = tempTilingGm->usedCoreNum; + is32BAligned = tempTilingGm->is32BAligned; + rowLen = tempTilingGm->rowLen; + colLen = tempTilingGm->colLen; + totalBlockLen = rowLen * colLen; + + baseRowLen = tempTilingGm->baseRowLen; + baseColLen = tempTilingGm->baseColLen; + + // input scalar parameters + threshold = tempTilingGm->threshold; + outliersNum = tempTilingGm->outliersNum; + isColQuant = tempTilingGm->isColQuant; + isOutlierIndex = tempTilingGm->isOutlierIndex; + + usedCoreNumForOutlier = tempTilingGm->usedCoreNumForOutlier; + baseCoreNumForOutlier = tempTilingGm->baseCoreNumForOutlier; + + baseCoreParam.colLen = tempTilingGm->baseCoreParam.colLen; + baseCoreParam.loopNum = tempTilingGm->baseCoreParam.loopNum; + baseCoreParam.tileCol = tempTilingGm->baseCoreParam.tileCol; + baseCoreParam.isTailExist = tempTilingGm->baseCoreParam.isTailExist; + baseCoreParam.tailCol = tempTilingGm->baseCoreParam.tailCol; + + tailCoreParam.colLen = tempTilingGm->tailCoreParam.colLen; + tailCoreParam.loopNum = tempTilingGm->tailCoreParam.loopNum; + tailCoreParam.tileCol = tempTilingGm->tailCoreParam.tileCol; + tailCoreParam.isTailExist = tempTilingGm->tailCoreParam.isTailExist; + tailCoreParam.tailCol = tempTilingGm->tailCoreParam.tailCol; + + auto alignedLen = DEFAULT_MIN_BLOCK_SIZE / 2; + tileLength = (is32BAligned == 1) ? (baseRowLen * baseColLen) : baseRowLen * ALIGNUP(baseColLen, alignedLen); + + rowAlignedLen = ALIGNUP(baseRowLen, alignedLen); + + baseRowTileNum = rowLen / baseRowLen; + baseColTileNum = colLen / baseColLen; + tailRowLen = rowLen % baseRowLen; + tailColLen = colLen % baseColLen; + rowTileNum = (tailRowLen > 0) ? (baseRowTileNum + 1) : baseRowTileNum; + colTileNum = (tailColLen > 0) ? (baseColTileNum + 1) : baseColTileNum; + totalTileNum = rowTileNum * colTileNum; + + coreIdx = AscendC::GetBlockIdx(); + if (coreIdx < usedCoreNum) { + CalcTileCopyParams(inputDataType); + } + if (coreIdx < usedCoreNumForOutlier) { + CalcOutlierParam(); + } + } + + __aicore__ inline void CalcOutlierParam() + { + if (coreIdx < baseCoreNumForOutlier) { + curCoreParam = baseCoreParam; + copyInOffset = coreIdx * baseCoreParam.colLen; + } else { + curCoreParam = tailCoreParam; + copyInOffset = colLen - tailCoreParam.colLen; + } + } + + __aicore__ inline void CalcOneTileCopyParam( + uint64_t calRowLen, uint64_t calColLen, uint32_t inputDTypeLen, CopyParam ©Param) + { + uint16_t blockUnit = (is32BAligned == 1) ? DEFAULT_MIN_BLOCK_SIZE : 1; + copyParam.blockCount = calRowLen; + copyParam.blockLen = calColLen * inputDTypeLen / blockUnit; + copyParam.blockLen_int8 = calColLen / blockUnit; + copyParam.stride = (calRowLen == 1) ? 0 : ((colLen - calColLen) * inputDTypeLen / blockUnit); + copyParam.stride_int8 = (calRowLen == 1) ? 
0 : ((colLen - calColLen) / blockUnit); + } + + __aicore__ inline void CalcTileCopyParams(uint32_t inputDTypeLen) + { + // zone1:baseRow-baseCol zone2:baseRow-tailCol zone3:tailRow-baseCol zone4:tailRow-tailCol + // base row , base col + bool aligned = (is32BAligned == 1); + auto alignedLen = DEFAULT_MIN_BLOCK_SIZE / inputDTypeLen; + baseRowBaseColCalLen = aligned ? (baseRowLen * baseColLen) : (baseRowLen * ALIGNUP(baseColLen, alignedLen)); + CalcOneTileCopyParam(baseRowLen, baseColLen, inputDTypeLen, baseRowBaseColCopyParam); + + // base row , tail col + baseRowTailColCalLen = aligned ? (baseRowLen * tailColLen) : baseRowLen * ALIGNUP(tailColLen, alignedLen); + CalcOneTileCopyParam(baseRowLen, tailColLen, inputDTypeLen, baseRowTailColCopyParam); + + // tail row , base col + tailRowBaseColCalLen = aligned ? (tailRowLen * baseColLen) : tailRowLen * ALIGNUP(baseColLen, alignedLen); + CalcOneTileCopyParam(tailRowLen, baseColLen, inputDTypeLen, tailRowBaseColCopyParam); + + // tail row , tail col + tailRowTailColCalLen = aligned ? (tailRowLen * tailColLen) : tailRowLen * ALIGNUP(tailColLen, alignedLen); + CalcOneTileCopyParam(tailRowLen, tailColLen, inputDTypeLen, tailRowTailColCopyParam); + } + + __aicore__ inline void CalcOneTileOffsetParam(uint64_t gmRowOffset, uint64_t rowIdx, uint64_t colIdx) + { + curTileOffset.rowIndex = rowIdx * baseRowLen; + curTileOffset.colIndex = colIdx * baseColLen; + gmOffset = gmRowOffset * colLen + colIdx * baseColLen; + } + + __aicore__ inline void SetCurTileParam( + uint64_t calTileLen_, uint64_t calRowLen_, uint64_t calColLen_, CopyParam *copyParam) + { + bool aligned = (is32BAligned == 1); + auto alignedLen = DEFAULT_MIN_BLOCK_SIZE / inputDataType; + curCalLen = calTileLen_; + curCalRowLen = calRowLen_; + curCalColLen = calColLen_; + curTileCopyParam = copyParam; + } + + __aicore__ inline void CalcOneTileParam(uint64_t tileIdx) + { + uint64_t rowTileIdx = tileIdx / colTileNum; + uint64_t colTileIdx = tileIdx % colTileNum; + CalcOneTileOffsetParam(rowTileIdx * baseRowLen, rowTileIdx, colTileIdx); + if (rowTileIdx < baseRowTileNum) { + if (colTileIdx < baseColTileNum) { + // base row, base col + SetCurTileParam(baseRowBaseColCalLen, baseRowLen, baseColLen, &baseRowBaseColCopyParam); + } else { + // base row, tail col + SetCurTileParam(baseRowTailColCalLen, baseRowLen, tailColLen, &baseRowTailColCopyParam); + } + } else { + if (colTileIdx < baseColTileNum) { + // tail row, base col + SetCurTileParam(tailRowBaseColCalLen, tailRowLen, baseColLen, &tailRowBaseColCopyParam); + } else { + // tail row, tail col + SetCurTileParam(tailRowTailColCalLen, tailRowLen, tailColLen, &tailRowTailColCopyParam); + } + } + } + }; + +#define ROW_COL_QUANT_PROCESS_TILE(gmOffset, copyParam, calLen) \ + CopyIn(gmOffset, copyParam); \ + this->Compute(calLen); \ + CopyOut(gmOffset, copyParam); + +#define ROW_COL_QUANT_PROCESS(kernelTiling) \ + do { \ + uint64_t blockNum = GetBlockNum(); \ + uint64_t baseTileNum = kernelTiling.totalTileNum / blockNum; \ + uint64_t oneMoreTileCoreNum = kernelTiling.totalTileNum % blockNum; \ + uint64_t startTileIdx, endTileIdx; \ + if (kernelTiling.coreIdx < oneMoreTileCoreNum) { \ + startTileIdx = kernelTiling.coreIdx * (baseTileNum + 1); \ + endTileIdx = startTileIdx + baseTileNum + 1; \ + } else { \ + startTileIdx = kernelTiling.coreIdx * baseTileNum + oneMoreTileCoreNum; \ + endTileIdx = startTileIdx + baseTileNum; \ + } \ + for (uint64_t tileIdx = startTileIdx; tileIdx < endTileIdx; tileIdx++) { \ + kernelTiling.CalcOneTileParam(tileIdx); \ + 
ROW_COL_QUANT_PROCESS_TILE( \ + kernelTiling.gmOffset, *(kernelTiling.curTileCopyParam), kernelTiling.curCalLen); \ + } \ + } while (0) + + constexpr uint32_t BUFFER_NUM = 1; + static constexpr float FLOAT_127 = 127.0f; + static constexpr float FRACTION_127 = 1.0 / FLOAT_127; + + template + class RowColQuantKernel { + public: + __aicore__ inline RowColQuantKernel() {} + __aicore__ inline ~RowColQuantKernel() {} + __aicore__ inline void Init(GM_ADDR xGm, GM_ADDR rowAbsGm, GM_ADDR colAbsGm, GM_ADDR rowNormedGm, + GM_ADDR colNormedGm, GM_ADDR rowIdxGm, GM_ADDR colIdxGm, GM_ADDR valueGm, + GM_ADDR tilingGm); + + __aicore__ inline void Process(); + protected: + __aicore__ inline void InitGmBuffer(GM_ADDR xGm_, GM_ADDR rowAbsGm_, GM_ADDR colAbsGm_, GM_ADDR rowNormedGm_, + GM_ADDR colNormedGm_, GM_ADDR rowIdx_, GM_ADDR colIdx_, GM_ADDR value_); + __aicore__ inline void InitUbBuffer(); + __aicore__ inline void CopyIn(uint64_t tileOffset, CopyParam ©Param); + __aicore__ inline void Compute(uint32_t curTileLen); + __aicore__ inline void CopyOut(uint64_t tileOffset, CopyParam ©Param); + __aicore__ inline void CopyInForOutlier(uint32_t offset, uint32_t calcLen); + __aicore__ inline void ComputeForOutlier(uint32_t offset, uint32_t calcLen, LocalTensor& outlierIdx, uint32_t& outlierNum); + __aicore__ inline void CopyOutForColIdx(); + + private: + TPipe pipe; + + GlobalTensor xGm; + GlobalTensor colAbsMaxGm, rowAbsMaxGm; + TQue inQueueX, inQueueColMax, inQueueRowMax, inQueueForOutlier, tmpQueueForOutlier; + TBuf xFloatBuffer; + TBuf thresholdDuplicateBuffer; + TBuf bitmapBuffer; + TBuf rowNormedSelectBuffer; + TBuf repeatFloat127Buffer; + + GlobalTensor colNormedGm, rowNormedGm; + GlobalTensor rowIdxGm, colIdxGm; + GlobalTensor valGm; + TQue outQueueColNormed, outQueueRowNormed, outQueueRowIdx, outQueueColIdx, outQueueValue; + uint32_t outlierNum = 0; + LocalTensor colIdxLocal; + TQue *tempQue; + + protected: + RowColQuantTilingKernel tiling; + }; + + template + __aicore__ inline void RowColQuantKernel::Init(GM_ADDR xGm, GM_ADDR rowAbsGm, + GM_ADDR colAbsGm, + GM_ADDR rowNormedGm, GM_ADDR colNormedGm, + GM_ADDR rowIdxGm, GM_ADDR colIdxGm, + GM_ADDR valueGm, GM_ADDR tilingGm) + { + tiling.GetTilingAndOffset(tilingGm, sizeof(InType)); + InitGmBuffer(xGm, rowAbsGm, colAbsGm, rowNormedGm, colNormedGm, rowIdxGm, colIdxGm, valueGm); + InitUbBuffer(); + } + + template + __aicore__ inline void RowColQuantKernel::Process() + { + if (tiling.coreIdx < tiling.usedCoreNum) { + if (tiling.is32BAligned == 1) { + ROW_COL_QUANT_PROCESS(tiling); + } + } + if (tiling.coreIdx < tiling.usedCoreNumForOutlier) { + if (tiling.isOutlierIndex == 0 && tiling.threshold > 0) { + for (uint32_t idx = 0; idx < tiling.curCoreParam.loopNum; idx++) { + uint32_t offset = idx * tiling.curCoreParam.tileCol + tiling.copyInOffset; + CopyInForOutlier(offset, tiling.curCoreParam.tileCol); + ComputeForOutlier(offset, tiling.curCoreParam.tileCol, colIdxLocal, outlierNum); + } + if (tiling.curCoreParam.isTailExist == 1) { + uint32_t offset = tiling.curCoreParam.loopNum * tiling.curCoreParam.tileCol + tiling.copyInOffset; + CopyInForOutlier(offset, tiling.curCoreParam.tailCol); + ComputeForOutlier(offset, tiling.curCoreParam.tailCol, colIdxLocal, outlierNum); + } + } + if (outlierNum > 0) { + CopyOutForColIdx(); + } + } + } + + template + __aicore__ inline void RowColQuantKernel::InitGmBuffer(GM_ADDR xGm_, GM_ADDR rowAbsGm_, + GM_ADDR colAbsGm_, GM_ADDR rowNormedGm_, + GM_ADDR colNormedGm_, GM_ADDR rowIdx_, + GM_ADDR colIdx_, GM_ADDR value_) + { + 
this->xGm.SetGlobalBuffer((__gm__ InType*)xGm_, tiling.totalBlockLen); + this->rowAbsMaxGm.SetGlobalBuffer((__gm__ CalType*)rowAbsGm_, tiling.rowLen); + this->colAbsMaxGm.SetGlobalBuffer((__gm__ CalType*)colAbsGm_, tiling.colLen); + + this->rowNormedGm.SetGlobalBuffer((__gm__ OutType*)rowNormedGm_, tiling.totalBlockLen); + this->colNormedGm.SetGlobalBuffer((__gm__ OutType*)colNormedGm_, tiling.totalBlockLen); + + // col index + this->rowIdxGm.SetGlobalBuffer((__gm__ int32_t*)rowIdx_, tiling.outliersNum); + this->colIdxGm.SetGlobalBuffer((__gm__ int32_t*)colIdx_, tiling.outliersNum); + this->valGm.SetGlobalBuffer((__gm__ InType*)value_, tiling.outliersNum); + } + + template + __aicore__ inline void RowColQuantKernel::InitUbBuffer() + { + if (tiling.coreIdx < tiling.usedCoreNum) { + pipe.InitBuffer(inQueueX, BUFFER_NUM, tiling.tileLength * sizeof(InType)); + pipe.InitBuffer(inQueueRowMax, BUFFER_NUM, tiling.rowAlignedLen * sizeof(CalType)); + pipe.InitBuffer(thresholdDuplicateBuffer, tiling.tileLength * sizeof(CalType)); + pipe.InitBuffer(bitmapBuffer, tiling.tileLength * sizeof(int8_t)); + pipe.InitBuffer(xFloatBuffer, tiling.tileLength * sizeof(CalType)); + pipe.InitBuffer(rowNormedSelectBuffer, tiling.tileLength * sizeof(CalType)); + pipe.InitBuffer(outQueueRowNormed, BUFFER_NUM, tiling.tileLength * sizeof(OutType)); + + if (tiling.isColQuant == 1){ + pipe.InitBuffer(inQueueColMax, BUFFER_NUM, tiling.baseColLen * sizeof(CalType)); + pipe.InitBuffer(repeatFloat127Buffer, tiling.baseColLen * sizeof(CalType)); + pipe.InitBuffer(outQueueColNormed, BUFFER_NUM, tiling.tileLength * sizeof(OutType)); + } + } + if (tiling.coreIdx < tiling.usedCoreNumForOutlier) { + outlierNum = 0; + if (tiling.isOutlierIndex == 1){ + pipe.InitBuffer(outQueueRowIdx, BUFFER_NUM, tiling.tileLength * sizeof(int32_t)); + pipe.InitBuffer(outQueueColIdx, BUFFER_NUM, tiling.tileLength * sizeof(int32_t)); + pipe.InitBuffer(outQueueValue, BUFFER_NUM, tiling.tileLength * sizeof(InType)); + tempQue = &inQueueX; + } else{ + pipe.InitBuffer(inQueueForOutlier, BUFFER_NUM, tiling.curCoreParam.colLen * sizeof(CalType)); + pipe.InitBuffer(outQueueColIdx, BUFFER_NUM, tiling.curCoreParam.colLen * sizeof(int32_t)); + colIdxLocal = outQueueColIdx.AllocTensor(); + pipe.InitBuffer(tmpQueueForOutlier, BUFFER_NUM, tiling.curCoreParam.colLen * sizeof(CalType)); + tempQue = &tmpQueueForOutlier; + } + } + } + + template + __aicore__ inline void RowColQuantKernel::CopyIn(uint64_t tileOffset, CopyParam ©Param) + { + DataCopyParams copyInParams = {copyParam.blockCount, copyParam.blockLen, copyParam.stride, 0}; + LocalTensor xLocal = inQueueX.AllocTensor(); + ::DataCopy(xLocal, xGm[tileOffset], copyInParams); + inQueueX.EnQue(xLocal); + + LocalTensor rowMaxLocal = inQueueRowMax.AllocTensor(); + ::DataCopy(rowMaxLocal, rowAbsMaxGm[tiling.curTileOffset.rowIndex], tiling.rowAlignedLen); + inQueueRowMax.EnQue(rowMaxLocal); + + if (tiling.isColQuant == 1) { + LocalTensor colMaxLocal = inQueueColMax.AllocTensor(); + ::DataCopy(colMaxLocal, colAbsMaxGm[tiling.curTileOffset.colIndex], tiling.baseColLen); + inQueueColMax.EnQue(colMaxLocal); + } + } + + template + __aicore__ inline void RowColQuantKernel::Compute(uint32_t curCalLen) + { + LocalTensor xLocal = inQueueX.DeQue(); + LocalTensor xFloatLocal = xFloatBuffer.Get(); + ::Cast(xFloatLocal, xLocal, RoundMode::CAST_NONE, curCalLen); + pipe_barrier(PIPE_V); + + LocalTensor rowNormedSelectLocal = rowNormedSelectBuffer.Get(); + if (tiling.threshold > 0){ + ::Abs(rowNormedSelectLocal, xFloatLocal, 
curCalLen); + + pipe_barrier(PIPE_V); + LocalTensor thresholdDuplicateLocal = thresholdDuplicateBuffer.Get(); + ::Duplicate(thresholdDuplicateLocal, tiling.threshold, curCalLen); + + pipe_barrier(PIPE_V); + LocalTensor bitmapLocal = bitmapBuffer.Get(); + ::Compare(bitmapLocal, rowNormedSelectLocal, thresholdDuplicateLocal, CMPMODE::LT, (curCalLen + 63) / 64 * 64); + pipe_barrier(PIPE_V); + + ::Select(rowNormedSelectLocal, bitmapLocal, xFloatLocal, 0.f, SELMODE::VSEL_TENSOR_SCALAR_MODE, curCalLen); + + if (tiling.isOutlierIndex == 1){ + LocalTensor bitmap16Buf = bitmapLocal.template ReinterpretCast(); + Not(bitmap16Buf, bitmap16Buf, curCalLen / 16); + pipe_barrier(PIPE_V); + uint64_t resv_cnt = 1; + GatherMask(xLocal, xLocal, bitmap16Buf, true, curCalLen, {1, 1, 8, 8}, resv_cnt); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + outlierNum = outlierNum + static_cast(resv_cnt); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + } + } else { + ::Adds(rowNormedSelectLocal, xFloatLocal, 0.f, curCalLen); + } + + pipe_barrier(PIPE_V); + LocalTensor rowMaxLocal = inQueueRowMax.DeQue(); + uint32_t rowBeginOffset = 0; + + for (uint32_t r = 0; r < tiling.curCalRowLen; r++) { + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + auto rowAbsMax = rowMaxLocal.GetValue(r); + rowBeginOffset = r * tiling.curCalColLen; + CalType factor = (rowAbsMax == 0 ? 0.f : FLOAT_127 / rowAbsMax); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + ::Muls(rowNormedSelectLocal[rowBeginOffset], rowNormedSelectLocal[rowBeginOffset], factor, tiling.curCalColLen); + } + pipe_barrier(PIPE_V); + + inQueueRowMax.FreeTensor(rowMaxLocal); + + LocalTensor tempInt16Local = xLocal.template ReinterpretCast(); + ::Cast(tempInt16Local, rowNormedSelectLocal, RoundMode::CAST_RINT, curCalLen); + pipe_barrier(PIPE_V); + + LocalTensor temphalfLocal = rowNormedSelectLocal.template ReinterpretCast(); + ::Cast(temphalfLocal, tempInt16Local, RoundMode::CAST_NONE, curCalLen); + pipe_barrier(PIPE_V); + + LocalTensor rowNormedLocal = outQueueRowNormed.AllocTensor(); + ::Cast(rowNormedLocal, temphalfLocal, RoundMode::CAST_NONE, curCalLen); + pipe_barrier(PIPE_V); + + outQueueRowNormed.EnQue(rowNormedLocal); + + if (tiling.isColQuant == 1) { + LocalTensor colMaxLocal = inQueueColMax.DeQue(); + LocalTensor repeatFloat127Local = repeatFloat127Buffer.Get(); + ::Duplicate(repeatFloat127Local, FLOAT_127, tiling.curCalColLen); + pipe_barrier(PIPE_V); + ::Div(colMaxLocal, repeatFloat127Local, colMaxLocal, tiling.curCalColLen); + pipe_barrier(PIPE_V); + + for (auto r = 0; r < tiling.curCalRowLen; ++r) { + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + rowBeginOffset = r * tiling.curCalColLen; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + ::Mul(xFloatLocal[rowBeginOffset], xFloatLocal[rowBeginOffset], colMaxLocal, tiling.curCalColLen); + } + + LocalTensor colNormedLocal = outQueueColNormed.AllocTensor(); + + ::Cast(tempInt16Local, xFloatLocal, RoundMode::CAST_RINT, curCalLen); + pipe_barrier(PIPE_V); + ::Cast(temphalfLocal, tempInt16Local, RoundMode::CAST_NONE, curCalLen); + pipe_barrier(PIPE_V); + ::Cast(colNormedLocal, temphalfLocal, RoundMode::CAST_NONE, curCalLen); + pipe_barrier(PIPE_V); + + outQueueColNormed.EnQue(colNormedLocal); + inQueueColMax.FreeTensor(colMaxLocal); + } + inQueueX.FreeTensor(xLocal); + } + + template + __aicore__ inline void 
RowColQuantKernel::CopyOut(uint64_t tileOffset, CopyParam ©Param) + { + DataCopyParams copyOutParams = {copyParam.blockCount, copyParam.blockLen_int8, 0, copyParam.stride_int8}; + LocalTensor rowNormedLocal = outQueueRowNormed.DeQue(); + ::DataCopy(rowNormedGm[tileOffset], rowNormedLocal, copyOutParams); + outQueueRowNormed.FreeTensor(rowNormedLocal); + + if (tiling.isColQuant == 1) { + LocalTensor colNormedLocal = outQueueColNormed.DeQue(); + ::DataCopy(colNormedGm[tileOffset], colNormedLocal, copyOutParams); + outQueueColNormed.FreeTensor(colNormedLocal); + } + } + + template + __aicore__ inline void RowColQuantKernel::CopyInForOutlier(uint32_t offset, uint32_t calcLen) + { + LocalTensor colMaxLocal = inQueueForOutlier.AllocTensor(); + ::DataCopy(colMaxLocal, colAbsMaxGm[offset], calcLen); + inQueueForOutlier.EnQue(colMaxLocal); + } + + template + __aicore__ inline void RowColQuantKernel::ComputeForOutlier(uint32_t offset, + uint32_t calcLen, + LocalTensor& outlierIdx, + uint32_t& outlierNum) + { + LocalTensor colMaxLocal = inQueueForOutlier.DeQue(); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + for (uint32_t c = 0; c < calcLen; c++) { + auto curVal = colMaxLocal.GetValue(c); + if (curVal >= tiling.threshold) { + outlierIdx.SetValue(outlierNum, offset + c); + outlierNum += 1; + } + } + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + inQueueForOutlier.FreeTensor(colMaxLocal); + } + + template + __aicore__ inline void RowColQuantKernel::CopyOutForColIdx() + { + DataCopyParams copyParams{1, static_cast(outlierNum * sizeof(int32_t)), 0, 0}; + ::DataCopyPad(colIdxGm[tiling.coreIdx * tiling.colLen], colIdxLocal, copyParams); + outQueueColIdx.FreeTensor(colIdxLocal); + } +} // namespace row_col_quant_kernel + + +namespace row_col_stats_fp16_kernel { + template + class RowColStatsKernelFp16 { + public: + __aicore__ inline RowColStatsKernelFp16() {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR rmaxbuf, GM_ADDR cmaxbuf, GM_ADDR cnt, GM_ADDR tiling_gm) + { + auto tiling_host = (__gm__ RowColStatsTiling*)tiling_gm; + tiling.M = tiling_host->M; + tiling.K = tiling_host->K; + tiling.threshold = tiling_host->threshold; + tiling.is_outlier_index = tiling_host->is_outlier_index; + tiling.use_gather_mask = tiling_host->use_gather_mask; + uint32_t blkid = get_block_idx(); + tiling.core_m = tiling_host->core_rows[blkid]; + tiling.core_k = tiling_host->core_cols[blkid]; + uint64_t start_off = tiling_host->start_offs[blkid]; + tiling.align_k = AlignTo16(tiling.core_k); + tiling.align_m = AlignTo16(tiling.core_m); + tiling.align_K = AlignTo16(tiling.K); + tiling.ub_sizes = tiling_host-> ub_sizes; + uint32_t max_elements_per_ub = tiling_host->max_elements_per_ub; + tiling.tile_lines = CalcTileLines(tiling.ub_sizes, max_elements_per_ub, InTypeSize, tiling.align_k, BUFFER_NUM); + tiling.start_off = start_off; + + // number of tile(tileLength per tile) on this core, don't include tail tile + tiling.tile_num = (tiling.tile_lines == 0) ? 0 : tiling.core_m / tiling.tile_lines; + tiling.tail_tile_lines = tiling.core_m - tiling.tile_lines * tiling.tile_num; + tiling.last_tile_idx = (tiling.tail_tile_lines > 0) ? 
tiling.tile_num : (tiling.tile_num - 1); + + xGm.SetGlobalBuffer((__gm__ InType*)x + start_off, (tiling.M * tiling.K - start_off) * InTypeSize); + rmaxGm.SetGlobalBuffer((__gm__ OutType*)rmaxbuf + (start_off / tiling.K), tiling.core_m * sizeof(OutType)); + cmaxGm.SetGlobalBuffer((__gm__ OutType*)cmaxbuf + (start_off % tiling.K), tiling.core_k * sizeof(OutType)); + cntGm.SetGlobalBuffer((__gm__ int32_t*)cnt, sizeof(int32_t)); + + uint32_t max_lines_per_tile = + tiling.tile_lines > tiling.tail_tile_lines ? tiling.tile_lines : tiling.tail_tile_lines; + pipe.InitBuffer(inQueue, BUFFER_NUM, + AlignToN(max_lines_per_tile * tiling.align_k * InTypeSize, ONE_REPEAT_BYTE_SIZE)); + pipe.InitBuffer(calcTBuf, AlignToN(max_lines_per_tile * tiling.align_k * InTypeSize, ONE_REPEAT_BYTE_SIZE)); + pipe.InitBuffer(rmaxQueue, BUFFER_NUM, AlignToN(max_lines_per_tile * sizeof(OutType), ONE_BLK_SIZE)); + pipe.InitBuffer(cmaxQueue, BUFFER_NUM, tiling.align_k * sizeof(OutType)); + pipe.InitBuffer(cntQueue, 1, 32); + pipe.InitBuffer(bitmapTBuf, AlignToN(max_lines_per_tile * tiling.align_k / UINT8_BITS, ONE_BLK_SIZE)); + calcBuf = calcTBuf.Get(); + bitmapBuf = bitmapTBuf.Get(); + outlier_cnt = 0; + + cntsBuf = cntQueue.template AllocTensor(); + cmaxCalcBuf = cmaxQueue.template AllocTensor(); + LocalTensor cnts32Buf = cntsBuf.template ReinterpretCast(); + Duplicate(cnts32Buf, (int32_t)0, 1); + cntQueue.EnQue(cnts32Buf); + DataCopyExtParams copyParams{1, (uint32_t)sizeof(int32_t), 0, 0, 0}; + cnts32Buf = cntQueue.template DeQue(); + + LocalTensor cmaxFloatBuf = cmaxCalcBuf.template ReinterpretCast(); + Duplicate(cmaxFloatBuf, (float)0.0, tiling.core_k); + cmaxQueue.EnQue(cmaxFloatBuf); + cmaxFloatBuf = cmaxQueue.template DeQue(); + + rmaxCalcBuf = rmaxQueue.template AllocTensor(); + LocalTensor rmaxFloatBuf = rmaxCalcBuf.template ReinterpretCast(); + + Duplicate(rmaxFloatBuf, (float)0.0, + AlignToN(max_lines_per_tile * sizeof(OutType), ONE_BLK_SIZE) / sizeof(OutType)); + + rmaxQueue.EnQue(rmaxFloatBuf); + rmaxFloatBuf = rmaxQueue.template DeQue(); + rmaxQueue.template FreeTensor(rmaxFloatBuf); + } + + __aicore__ inline void Process() + { + do { + for (uint64_t i = 0; i < tiling.tile_num; i++) { + CopyIn(i, tiling.tile_lines); + Compute(i, tiling.tile_lines); + CopyOut(i, tiling.tile_lines); + } + if (tiling.tail_tile_lines > 0) { + CopyIn(tiling.tile_num, tiling.tail_tile_lines); + Compute(tiling.tile_num, tiling.tail_tile_lines); + CopyOut(tiling.tile_num, tiling.tail_tile_lines); + } + } while(0); + } + + private: + __aicore__ inline void CopyIn(int32_t progress, uint32_t cur_lines) + { + LocalTensor xLocal = inQueue.template AllocTensor(); + if (tiling.K == tiling.align_K) { + DataCopyParams copyParams{(uint16_t)cur_lines, (uint16_t)(tiling.core_k / 16), + (uint16_t)((tiling.K - tiling.core_k) / 16), 0}; + DataCopy(xLocal, xGm[progress * tiling.tile_lines * tiling.K], copyParams); + } else { + DataCopyExtParams copyParams{(uint16_t)cur_lines, (uint32_t)(tiling.core_k * InTypeSize), + (uint32_t)(tiling.K - tiling.core_k) * InTypeSize, 0, 0}; + DataCopyPadExtParams padParams{(tiling.core_k != tiling.align_k), 0, + (uint8_t)(tiling.align_k - tiling.core_k), 0}; + DataCopyPad(xLocal, xGm[progress * tiling.tile_lines * tiling.K], copyParams, padParams); + } + inQueue.EnQue(xLocal); + } + + __aicore__ inline void ComputeColMaxs(int32_t progress, uint32_t cur_lines, LocalTensor& xLocal) + { + if (progress == 0) { + Duplicate(cmaxCalcBuf, (InType)0, tiling.align_k); + } + if (cur_lines > 1) { + uint32_t left_num = 
cur_lines; + uint32_t half_num = left_num >> 1; + Max(calcBuf, xLocal, xLocal[half_num * tiling.align_k], half_num * tiling.align_k); + if (left_num % 2) { + Max(cmaxCalcBuf, cmaxCalcBuf, xLocal[(left_num - 1) * tiling.align_k], tiling.align_k); + } + left_num = half_num; + uint32_t off = 0; + while(left_num > 1) { + half_num = left_num >> 1; + Max(calcBuf[(off + half_num) * tiling.align_k], calcBuf[off * tiling.align_k], + calcBuf[(off + half_num) * tiling.align_k], half_num * tiling.align_k); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + off += half_num; + left_num -= half_num; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + } + Max(cmaxCalcBuf, cmaxCalcBuf, calcBuf[off * tiling.align_k], tiling.align_k); + } else { + Max(cmaxCalcBuf, cmaxCalcBuf, xLocal, tiling.align_k); + } + } + + __aicore__ inline void ComputeRowMaxs(int32_t progress, uint32_t cur_lines, LocalTensor& xLocal) + { + uint32_t total_calc_cnt = cur_lines * tiling.align_k; + half threshold = static_cast(tiling.threshold); + ComplexCompareScalar(bitmapBuf, xLocal, threshold, CMPMODE::LT, total_calc_cnt); + ComplexSelectScalar(calcBuf, bitmapBuf, xLocal, (InType)0, total_calc_cnt); + + LocalTensor rmaxBuf = rmaxQueue.template AllocTensor(); + for (uint32_t i = 0; i < cur_lines; i++) { + AscendC::ReduceMax(calcBuf[i], calcBuf[i * tiling.align_k], xLocal, (int32_t)tiling.core_k, false); + } + LocalTensor rmaxFloatBuf = rmaxBuf.template ReinterpretCast(); + Cast(rmaxFloatBuf, calcBuf, RoundMode::CAST_NONE, cur_lines); + rmaxQueue.EnQue(rmaxFloatBuf); + + if (tiling.is_outlier_index) { + CalcTilingOutlierCnts(progress, cur_lines, xLocal); + } + } + + __aicore__ inline void CalcTilingOutlierCnts(int32_t progress, uint32_t cur_lines, LocalTensor& xLocal) + { + uint32_t total_calc_cnt = cur_lines * tiling.align_k; + LocalTensor bitmap16Buf = bitmapBuf.template ReinterpretCast(); + Not(bitmap16Buf, bitmap16Buf, total_calc_cnt / 16); + uint64_t resv_cnt = 1; + GatherMask(calcBuf, calcBuf, bitmap16Buf, true, total_calc_cnt, {1, 1, 8, 8}, resv_cnt); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + outlier_cnt += (int32_t)resv_cnt; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + } + + __aicore__ inline void CalcAllOutlierCnts(int32_t progress) + { + LocalTensor cnts32Buf = cntsBuf.template ReinterpretCast(); + Duplicate(cnts32Buf, (int32_t)outlier_cnt, 1); + cntQueue.EnQue(cnts32Buf); + } + + template + __aicore__ inline void CalcAllOutlierCols(LocalTensor& cmaxBuf, LocalTensor& xLocal) + { + T threshold = static_cast(tiling.threshold); + ComplexCompareScalar(bitmapBuf, cmaxBuf, threshold, CMPMODE::GE, tiling.core_k); + LocalTensor cnts32Buf = cntsBuf.template ReinterpretCast(); + LocalTensor cntsTBuf = cntsBuf.template ReinterpretCast(); + LocalTensor calcTBuf = calcBuf.template ReinterpretCast(); + uint64_t resv_cnt = 1; + LocalTensor bitmap16Buf = bitmapBuf.template ReinterpretCast(); + GatherMask(calcTBuf, calcTBuf, bitmap16Buf, true, tiling.core_k, {1, 1, 8, 8}, resv_cnt); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + outlier_cnt += (int32_t)resv_cnt; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + Duplicate(cnts32Buf, (int32_t)outlier_cnt, 1); + cntQueue.EnQue(cnts32Buf); + } + + __aicore__ inline void Compute(int32_t progress, uint32_t cur_lines) + { + LocalTensor xLocal = inQueue.template DeQue(); + Abs(xLocal, xLocal, cur_lines * 
tiling.align_k); + + ComputeColMaxs(progress, cur_lines, xLocal); + + if (tiling.threshold > 0) { + ComputeRowMaxs(progress, cur_lines, xLocal); + if (progress == tiling.last_tile_idx) { + if (tiling.is_outlier_index) { + CalcAllOutlierCnts(progress); + } else if (tiling.core_m == tiling.M) { + CalcAllOutlierCols(cmaxCalcBuf, xLocal); + } + } + } else { + LocalTensor rmaxBuf = rmaxQueue.template AllocTensor(); + uint32_t rid = 0; + for (uint32_t i = 0; i < cur_lines; i++) { + ReduceMax(xLocal[i], xLocal[i * tiling.align_k], calcBuf, (int32_t)tiling.core_k); + } + LocalTensor rmaxFloatBuf = rmaxBuf.template ReinterpretCast(); + Cast(rmaxFloatBuf, xLocal, RoundMode::CAST_NONE, cur_lines); + rmaxQueue.EnQue(rmaxFloatBuf); + + if (progress == tiling.last_tile_idx) { + LocalTensor cnts32Buf = cntsBuf.template ReinterpretCast(); + Duplicate(cnts32Buf, (int32_t)outlier_cnt, 1); + cntQueue.EnQue(cnts32Buf); + } + } + inQueue.FreeTensor(xLocal); + + if (progress == tiling.last_tile_idx) { + LocalTensor cmaxFloatBuf = cmaxCalcBuf.template ReinterpretCast(); + ComplexCopy(calcBuf, cmaxCalcBuf, tiling.core_k); + Cast(cmaxFloatBuf, calcBuf, RoundMode::CAST_NONE, tiling.core_k); + cmaxQueue.EnQue(cmaxFloatBuf); + } + } + + __aicore__ inline void CopyOutRmax(int32_t progress, uint32_t cur_lines) { + LocalTensor rmaxBuf = rmaxQueue.template DeQue(); + if (cur_lines % 16) { + DataCopyExtParams copyParams{1, (uint32_t)(cur_lines * sizeof(OutType)), 0, 0, 0}; + if (tiling.core_k == tiling.K) { + DataCopyPad(rmaxGm[progress * tiling.tile_lines], rmaxBuf, copyParams); + } else { + SetAtomicMax(); + DataCopyPad(rmaxGm[progress * tiling.tile_lines], rmaxBuf, copyParams); + SetAtomicNone(); + } + } else { + if (tiling.core_k == tiling.K) { + DataCopy(rmaxGm[progress * tiling.tile_lines], rmaxBuf, cur_lines); + } else { + SetAtomicMax(); + DataCopy(rmaxGm[progress * tiling.tile_lines], rmaxBuf, cur_lines); + SetAtomicNone(); + } + } + rmaxQueue.FreeTensor(rmaxBuf); + } + + __aicore__ inline void CopyOutCmax() { + LocalTensor cmaxBuf = cmaxQueue.template DeQue(); + if (tiling.core_k == tiling.align_k) { + if (tiling.core_m == tiling.M) { + DataCopy(cmaxGm, cmaxBuf, tiling.core_k); + } else { + SetAtomicMax(); + DataCopy(cmaxGm, cmaxBuf, tiling.core_k); + SetAtomicNone(); + } + } else { + DataCopyExtParams copyParams{1, (uint32_t)(tiling.core_k * sizeof(OutType)), 0, 0, 0}; + if (tiling.core_m == tiling.M) { + DataCopyPad(cmaxGm, cmaxBuf, copyParams); + } else { + SetAtomicMax(); + DataCopyPad(cmaxGm, cmaxBuf, copyParams); + SetAtomicNone(); + } + } + cmaxQueue.FreeTensor(cmaxBuf); + } + + __aicore__ inline void CalcFinalOutlierCols() { + if (tiling.start_off / tiling.K > 0) { + LocalTensor cnts32Buf = cntsBuf.template ReinterpretCast(); + Duplicate(cnts32Buf, (int32_t)0, 1); + cntQueue.EnQue(cnts32Buf); + return; + } + LocalTensor cmaxFBuf = cmaxQueue.template AllocTensor(); + if (tiling.K == tiling.align_K) { + DataCopy(cmaxFBuf, cmaxGm, tiling.core_k); + } else { + uint32_t blockLen = (uint32_t)(tiling.core_k * sizeof(float)); + DataCopyExtParams copyParams{(uint16_t)1, blockLen, 0, 0, 0}; + DataCopyPadExtParams padParams{true, 0, (uint8_t)((AlignTo32(blockLen) - blockLen) / sizeof(float)), 0}; + DataCopyPad(cmaxFBuf, cmaxGm, copyParams, padParams); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + LocalTensor xLocal = inQueue.template AllocTensor(); + CalcAllOutlierCols(cmaxFBuf, xLocal); + inQueue.FreeTensor(xLocal); + cmaxQueue.FreeTensor(cmaxFBuf); + } + + __aicore__ 
inline void CopyOutOutlierCnt() { + if ((!tiling.is_outlier_index) && (tiling.threshold > 0) && (tiling.core_m != tiling.M)) { + CalcFinalOutlierCols(); + } + DataCopyExtParams copyParams{1, (uint32_t)sizeof(int32_t), 0, 0, 0}; + LocalTensor cnts32Buf = cntQueue.template DeQue(); + SetAtomicAdd(); + DataCopyPad(cntGm, cnts32Buf, copyParams); + SetAtomicNone(); + cntQueue.FreeTensor(cnts32Buf); + } + + __aicore__ inline void CopyOut(int32_t progress, uint32_t cur_lines) + { + CopyOutRmax(progress, cur_lines); + if (progress == tiling.last_tile_idx) { + CopyOutCmax(); + CopyOutOutlierCnt(); + } + } + + __aicore__ inline uint32_t + CalcTileLines(uint32_t ub_size, uint64_t max_datas_per_ub, uint32_t dtype_len, uint32_t align_k, uint32_t buffer_num) { + uint32_t tiling_lines = (ub_size - buffer_num * 320 - 320 - 4 * buffer_num * align_k) * 8 + / ((dtype_len * (buffer_num + 1) * 8 + 1) * align_k + 4 * 8 * buffer_num); + uint32_t align_num = 32 / sizeof(uint16_t); + uint32_t aligned_tiling_lines = tiling_lines / align_num * align_num; + tiling_lines = (aligned_tiling_lines == 0) ? tiling_lines : aligned_tiling_lines; + + return tiling_lines; + } + + static constexpr uint32_t UINT8_BITS = ONE_BYTE_BIT_SIZE; + + template + __aicore__ inline void ComplexCompareScalar(const LocalTensor& dstLocal, const LocalTensor& src0Local, + const T src1Scalar, CMPMODE cmpMode, uint32_t calCount) + { + UnaryRepeatParams repeatParams{1, 1, 1, 8}; + if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 64) ? (calCount / 64 + 1) : (calCount / 64); + uint32_t off = 0; + while (repeat > 248) { + CompareScalar(dstLocal[off * 8], src0Local[off * 64], src1Scalar, cmpMode, 64, 248, repeatParams); + repeat -= 248; + off += 248; + } + if (repeat > 0) { + CompareScalar(dstLocal[off * 8], src0Local[off * 64], src1Scalar, cmpMode, 64, repeat, repeatParams); + off += repeat; + } + } else if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 128) ? (calCount / 128 + 1) : (calCount / 128); + uint32_t off = 0; + + while (repeat > 254) { + CompareScalar(dstLocal[off * 16], src0Local[off * 128], src1Scalar, cmpMode, 128, 254, repeatParams); + repeat -= 254; + off += 254; + } + if (repeat > 0) { + CompareScalar(dstLocal[off * 16], src0Local[off * 128], src1Scalar, cmpMode, 128, repeat, repeatParams); + off += repeat; + } + } + } + + template + __aicore__ inline void ComplexSelectScalar(const LocalTensor& dstLocal, const LocalTensor& selMask, + const LocalTensor& src0Local, T src1Scalar, uint32_t calCount) + { + BinaryRepeatParams repeatParams{1, 1, 1, 8, 8, 1}; + if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 64) ? (calCount / 64 + 1) : (calCount / 64); + uint32_t off = 0; + while (repeat > 248) { + Select(dstLocal[off * 64], selMask[off * 8], src0Local[off * 64], src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, 248, repeatParams); + repeat -= 248; + off += 248; + } + if (repeat > 0) { + Select(dstLocal[off * 64], selMask[off * 8], src0Local[off * 64], src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, (uint8_t)repeat, repeatParams); + off += repeat; + } + } else if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 128) ? 
(calCount / 128 + 1) : (calCount / 128); + uint32_t off = 0; + while (repeat > 254) { + Select(dstLocal[off * 128], selMask[off * 16], src0Local[off * 128], src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, 254, repeatParams); + repeat -= 254; + off += 254; + } + if (repeat > 0) { + Select(dstLocal[off * 128], selMask[off * 16], src0Local[off * 128], src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, (uint8_t)repeat, repeatParams); + off += repeat; + } + } + } + + template + __aicore__ inline void ComplexSelectScalar2(const LocalTensor& dstLocal, const LocalTensor& selMask, + const LocalTensor& src0Local, T src1Scalar, uint32_t calCount) + { + BinaryRepeatParams repeatParams{1, 0, 1, 8, 0, 8}; + if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 64) ? (calCount / 64 + 1) : (calCount / 64); + uint32_t off = 0; + while (repeat > 248) { + Select(dstLocal[off * 64], selMask[off * 8], src0Local, src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, 248, repeatParams); + repeat -= 248; + off += 248; + } + if (repeat > 0) { + Select(dstLocal[off * 64], selMask[off * 8], src0Local, src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, (uint8_t)repeat, repeatParams); + off += repeat; + } + } else if constexpr (std::is_same_v) { + uint32_t repeat = (calCount % 128) ? (calCount / 128 + 1) : (calCount / 128); + uint32_t off = 0; + while (repeat > 254) { + Select(dstLocal[off * 128], selMask[off * 16], src0Local, src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, 254, repeatParams); + repeat -= 254; + off += 254; + } + if (repeat > 0) { + Select(dstLocal[off * 128], selMask[off * 16], src0Local, src1Scalar, + SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, (uint8_t)repeat, repeatParams); + } + } + } + + template + __aicore__ inline void ComplexCopy(const LocalTensor &dstLocal, const LocalTensor &srcLocal, uint32_t calCount) + { + uint32_t repeat_eles = 256 / sizeof(T); + uint32_t repeat = calCount / repeat_eles; + CopyRepeatParams repeatParams{1, 1, 8, 8}; + uint32_t off = 0; + while (repeat > 255) { + Copy(dstLocal[off * repeat_eles], srcLocal[off * repeat_eles], repeat_eles, 255, repeatParams); + repeat -= 255; + off += 255; + } + if (repeat > 0) { + Copy(dstLocal[off * repeat_eles], srcLocal[off * repeat_eles], repeat_eles, (uint8_t)repeat, repeatParams); + off += repeat; + } + repeat = calCount % 128; + if (repeat) { + Copy(dstLocal[off * repeat_eles], srcLocal[off * repeat_eles], repeat, (uint8_t)1, repeatParams); + } + } + + __aicore__ inline uint32_t AlignTo16(uint32_t n) + { + return (n + 15) / 16 * 16; + } + + __aicore__ inline uint32_t AlignTo32(uint32_t n) + { + return (n + 31) / 32 * 32; + } + + __aicore__ inline uint32_t AlignToN(uint32_t m, uint32_t n) + { + return (m + n - 1) / n * n; + } + + private: + TPipe pipe; + TQue inQueue; + TQue rmaxQueue; + TQue cmaxQueue; + TQue cntQueue; + GlobalTensor xGm; + GlobalTensor rmaxGm; + GlobalTensor cmaxGm; + GlobalTensor cntGm; + TBuf calcTBuf; + TBuf bitmapTBuf; + LocalTensor calcBuf; + LocalTensor bitmapBuf; + LocalTensor cmaxCalcBuf; + LocalTensor rmaxCalcBuf; + LocalTensor cntsBuf; + RowColStatsTilingKernel tiling; + int32_t outlier_cnt; + static constexpr uint32_t InTypeSize = sizeof(InType); + static constexpr int32_t InTypeStripe = 32 / sizeof(InType); + static constexpr uint32_t VEC_REPEAT_SIZE = DEFAULT_BLOCK_SIZE / sizeof(InType); + }; +} extern "C" { @@ -219,4 +1287,20 @@ __global__ __aicore__ void dequantize_blockwise_fp16_nf4(GM_ADDR A, GM_ADDR absm op.Process(); } +__global__ __aicore__ void row_col_quant(GM_ADDR x, GM_ADDR 
rowAbsMax, GM_ADDR colAbsMax, GM_ADDR outRowNormed, + GM_ADDR outColNormed, GM_ADDR outliersRowIdx, GM_ADDR outliersColIdx, + GM_ADDR outliersValue, GM_ADDR tiling) +{ + row_col_quant_kernel::RowColQuantKernel op; + op.Init(x, rowAbsMax, colAbsMax, outRowNormed, outColNormed, outliersRowIdx, outliersColIdx, outliersValue, tiling); + op.Process(); } + +__global__ __aicore__ void row_col_stats(GM_ADDR x, GM_ADDR rmax, GM_ADDR cmax, GM_ADDR cnt, GM_ADDR tiling) +{ + row_col_stats_fp16_kernel::RowColStatsKernelFp16 op; + op.Init(x, rmax, cmax, cnt, tiling); + op.Process(); +} + +} \ No newline at end of file diff --git a/csrc/npu_ops.cpp b/csrc/npu_ops.cpp index fb5ecef2f..f2007cd24 100644 --- a/csrc/npu_ops.cpp +++ b/csrc/npu_ops.cpp @@ -1,10 +1,13 @@ #include +#include #include "acl/acl.h" #include "tiling/platform/platform_ascendc.h" #include "npu_ops.h" #include "aclrtlaunch_dequantize_blockwise_fp32_nf4.h" #include "aclrtlaunch_dequantize_blockwise_fp16_nf4.h" +#include "aclrtlaunch_row_col_stats.h" +#include "aclrtlaunch_row_col_quant.h" extern "C" { @@ -35,7 +38,7 @@ void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t tilingHost = (struct BlockwiseNf4TilingData *)malloc(tilingSize); uint32_t error = get_dequantize_blockwise_nf4_tiling(blocksize, n, tilingHost); if (error != 0) { - printf("[!] error\n"); + printf("An error occurred.\n"); } uint8_t *tilingDevice = nullptr; aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_NORMAL_ONLY); @@ -48,4 +51,509 @@ void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t aclrtFree(tilingDevice); } +namespace get_row_col_quant_tiling { + const uint32_t DEFAULT_BUFFER_NUM = 1; + const uint32_t TQUE_ROW_COL_FP16_NUM = 1; + const uint32_t TQUE_ROW_COL_FLOAT_NUM = 2; + const uint32_t TQUE_ROW_COL_INT8_NUM = 2; + const uint32_t TQUE_ROW_COL_INT32_NUM = 1; + const uint32_t TBUF_ROW_COL_FLOAT_X_NUM = 2; + const uint32_t TBUF_ROW_COL_FLOAT_THR_REPEAT_NUM = 1; + const uint32_t TBUF_ROW_COL_BITMAP_NUM = 1; + const uint32_t TBUF_ROW_COL_NORM_CAST_NUM = 2; + const uint32_t TBUF_ROW_COL_ROW_SELECT_NUM = 1; + const uint32_t TBUF_ROW_COL_REPEAT_127_NUM = 1; + const uint32_t DTYPE_FLOAT16_SIZE = 2; + + struct RowColQuantCalculator : public RowColQuantTilingData { + public: + bool CalcTiling(uint32_t totalCore, uint64_t ubSize, int32_t dtype); + bool SetTotalShape(int rows, int cols); + bool SetInputAttr(uint32_t outliersNum, float threshold, bool isColQuant, bool isOutlierIndex); + + private: + inline bool CalcTileColMax(uint64_t ubSize, uint16_t bufferNum); + inline bool CalcOutlierTiling(uint32_t totalCore, uint64_t ubSize); + inline uint32_t CalcTileRowMaxByCol(uint64_t ubSize, uint16_t bufferNum, uint64_t tileCol); + inline void SaveOptBaseShape(uint32_t baseRowLen_, uint32_t baseColLen_); + inline uint32_t getBaseColLenUpBound(); + inline uint32_t getBaseRowLenUpBound(); + inline bool MustBeSingleBaseRowLen(uint32_t baseColLen_); + inline bool isInvalidBaseShape(uint32_t baseRowLen_, uint32_t baseColLen_); + inline bool CalcOptBaseShape(uint64_t ubSize); + + uint32_t tileColMax = 0; + uint32_t inputDTypeLen = 2; + // Indicates the minimum processing data unit of the UB. Unit:element. + // Formula: 32B/sizeof(DType). For example, if Dtype is BF16, ubMinBlockLen = 32/2 = 16 + uint32_t ubMinBlockLen = 0; + // Length of the L2 cache line. Unit:element. + // Formula: 512B/sizeof(DType). 
For example, if the Dtype is BF16, cacheLineLen = 512/2 = 256 + uint32_t cacheLineLen = 0; + // baseColLen aligned package Len. element:Unit. 512-byte alignment or 32-byte alignment + uint32_t alignPackLen = 0; + // Maximum amount of data that can be transferred by an operator UB at a time. Unit:element + uint32_t maxTileLen = 0; + uint32_t optBaseRowLen = 0; + uint32_t optBaseColLen = 0; + uint64_t optTotalTileNum = 0; + uint64_t optBaseSize = 0; + uint64_t optBaseTileNum = 0; + }; + + inline bool GetLengthByType(int32_t dtype, uint32_t& dsize) + { + dsize = sizeof(int16_t); + return true; + } + + inline bool RowColQuantCalculator::SetTotalShape(int rows, int cols) + { + rowLen = rows; + colLen = cols; + return true; + } + + inline bool RowColQuantCalculator::CalcTileColMax(uint64_t ubSize, uint16_t bufferNum) + { + auto base = bufferNum * (sizeof(int16_t) + sizeof(int8_t)) + sizeof(float) * 3 + sizeof(int8_t); + if (isColQuant == 1) { + base += bufferNum * (sizeof(float) + sizeof(int8_t)) + sizeof(float); + } + if (isOutlierIndex == 1) { + base += bufferNum * (sizeof(int32_t) * 2 + sizeof(int16_t)); + } else { + base += bufferNum * (sizeof(float) + sizeof(int32_t)); + } + + tileColMax = ALIGNDOWN((ubSize - 32) / base, L2_CACHE_LINE_SIZE); + return true; + } + + inline uint32_t RowColQuantCalculator::CalcTileRowMaxByCol(uint64_t ubSize, uint16_t bufferNum, uint64_t tileCol) + { + auto base = (bufferNum * (sizeof(int16_t) + sizeof(int8_t)) + sizeof(float) * 3 + sizeof(int8_t)) * tileCol + sizeof(float); + if (isColQuant == 1) { + base += bufferNum * sizeof(int8_t) * tileCol; + ubSize -= (bufferNum * sizeof(float) + sizeof(float)) * tileCol; + } + if (isOutlierIndex == 1) { + base += bufferNum * (sizeof(int32_t) * 2 + sizeof(int16_t)) * tileCol; + } else { + ubSize -= bufferNum * (sizeof(float) + sizeof(int32_t)) * tileCol; + } + + return (ubSize - 32) / base; + } + + inline void RowColQuantCalculator::SaveOptBaseShape(uint32_t baseRowLen_, uint32_t baseColLen_) + { + uint64_t totalTileNum = DIVCEIL(rowLen, baseRowLen_) * DIVCEIL(colLen, baseColLen_); + uint64_t baseSize = baseRowLen_ * baseColLen_; + uint64_t baseTileNum = (rowLen / baseRowLen_) * (colLen / baseColLen_); + + optBaseRowLen = baseRowLen_; + optBaseColLen = baseColLen_; + optTotalTileNum = totalTileNum; + optBaseSize = baseSize; + optBaseTileNum = baseTileNum; + } + + inline uint32_t RowColQuantCalculator::getBaseColLenUpBound() + { + uint32_t upBound = std::min(colLen, (uint64_t)tileColMax); + if (is32BAligned == 1) { + upBound = std::min(upBound, (uint32_t)DISCONTINE_COPY_MAX_BLOCKLEN); + } else { + upBound = std::min(upBound, (uint32_t)DISCONTINE_COPY_MAX_BLOCKLEN / inputDTypeLen); + } + + return upBound; + } + + inline uint32_t RowColQuantCalculator::getBaseRowLenUpBound() + { + return std::min(rowLen, (uint64_t)DISCONTINE_COPY_MAX_BLOCKCNT); + } + + inline bool RowColQuantCalculator::MustBeSingleBaseRowLen(uint32_t baseColLen_) + { + if (is32BAligned == 1) { + return ((colLen * 2 - baseColLen_) > (DISCONTINE_COPY_MAX_STRIDE * ubMinBlockLen)); + } + + return (((colLen * 2 - baseColLen_) * inputDTypeLen) > DISCONTINE_COPY_MAX_STRIDE); + } + + inline bool RowColQuantCalculator::isInvalidBaseShape(uint32_t baseRowLen_, uint32_t baseColLen_) + { + return ((baseRowLen_ < 1) || (baseRowLen_ > 1 && MustBeSingleBaseRowLen(baseColLen_))); + } + + inline bool RowColQuantCalculator::CalcOptBaseShape(uint64_t ubSize) + { + uint32_t baseColLen_ = getBaseColLenUpBound(); + if (MustBeSingleBaseRowLen(baseColLen_)) { + 
SaveOptBaseShape(1, baseColLen_); + return true; + } + + uint32_t baseRowLen_ = std::min(CalcTileRowMaxByCol(ubSize, DEFAULT_BUFFER_NUM, baseColLen_), getBaseRowLenUpBound()); + if (isInvalidBaseShape(baseRowLen_, baseColLen_)) { + return (optTotalTileNum > 0); + } + SaveOptBaseShape(baseRowLen_, baseColLen_); + + return true; + } + + inline bool RowColQuantCalculator::CalcOutlierTiling(uint32_t totalCore, uint64_t ubSize) { + uint32_t MIN_BLOCK_ALIGN_LEN = UB_MIN_BLOCK_SIZE / sizeof(float); + uint32_t baseCoreCalcLens = ALIGNUP(DIVCEIL(colLen, totalCore), MIN_BLOCK_ALIGN_LEN); + baseCoreNumForOutlier = colLen / baseCoreCalcLens; + usedCoreNumForOutlier = baseCoreNumForOutlier; + baseCoreParam.colLen = baseCoreCalcLens; + if (baseCoreCalcLens >= tileColMax) { + baseCoreParam.loopNum = (uint32_t)baseCoreCalcLens / (uint32_t)tileColMax; + baseCoreParam.tileCol = tileColMax; + } + if (baseCoreCalcLens % tileColMax != 0) { + baseCoreParam.isTailExist = 1; + baseCoreParam.tailCol = baseCoreCalcLens % tileColMax; + } + + if (colLen % baseCoreCalcLens != 0) { + usedCoreNumForOutlier += 1; + tailCoreParam.colLen = ALIGNUP(colLen % baseCoreCalcLens, MIN_BLOCK_ALIGN_LEN); + if (tailCoreParam.colLen >= tileColMax) { + tailCoreParam.loopNum = (uint32_t)tailCoreParam.colLen / (uint32_t)tileColMax; + tailCoreParam.tileCol = tileColMax; + } + if (tailCoreParam.colLen % tileColMax != 0) { + tailCoreParam.isTailExist = 1; + tailCoreParam.tailCol = tailCoreParam.colLen % tileColMax; + } + } + + return true; + } + + bool RowColQuantCalculator::CalcTiling(uint32_t totalCore, uint64_t ubSize, int32_t dtype) + { + if (!GetLengthByType(dtype, inputDTypeLen)) { + printf("Unsupported input data type %d\n", dtype); + return false; + } + ubMinBlockLen = UB_MIN_BLOCK_SIZE / inputDTypeLen; // min block size + cacheLineLen = L2_CACHE_LINE_SIZE / inputDTypeLen; // bandwidth max efficiency + alignPackLen = cacheLineLen; + + ubSize -= UB_RESERVED_BUFF; + if (!CalcTileColMax(ubSize, DEFAULT_BUFFER_NUM)) { + return false; + } + + is32BAligned = colLen % ubMinBlockLen == 0; + + if (!CalcOptBaseShape(ubSize)) { + return false; + } + baseRowLen = optBaseRowLen; + baseColLen = optBaseColLen; + usedCoreNum = std::min(optTotalTileNum, (uint64_t)totalCore); + usedCoreNumForOutlier = usedCoreNum; + if (isOutlierIndex == 0) { + CalcOutlierTiling(totalCore, ubSize); + } + return true; + } + + bool RowColQuantCalculator::SetInputAttr(uint32_t outliers_num, float in_threshold, bool is_col_quant, bool is_outlier_index) + { + outliersNum = outliers_num; + threshold = in_threshold; + isColQuant = (is_col_quant ? 1 : 0); + isOutlierIndex = (is_outlier_index ? 
1 : 0); + return true; + } + + uint32_t TilingForRowColQuant(uint32_t outliers_num, float in_threshold, bool is_col_quant, bool is_outlier_index, + int rows, int cols, uint32_t totalCore, + get_row_col_quant_tiling::RowColQuantCalculator *tilingCalc) + { + uint64_t ubSize = 192 * 1024; + if (totalCore < 0 || totalCore >= MAX_CORE_NUMBER || ubSize <= UB_RESERVED_BUFF) { + printf("Compile Info is invalid, coreNum:%u, ubSize:%lu\n", totalCore, ubSize); + return 1; + } + + ubSize -= UB_RESERVED_BUFF; + + if (!tilingCalc->SetInputAttr(outliers_num, in_threshold, is_col_quant, is_outlier_index)) { + printf("Parse input attrs failed\n"); + return 1; + } + + if (!tilingCalc->SetTotalShape(rows, cols) || !tilingCalc->CalcTiling(totalCore, ubSize, DTYPE_FLOAT16_SIZE)) { + return 1; + } + return 0; + } +} + +void rowColQuant(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outRowNormed, uint8_t *outColNormed, + uint8_t *outliersRowIdx, uint8_t *outliersColIdx, uint8_t *outliersValue, uint32_t outliersNum, + float threshold, int rows, int cols, void* stream) { + uint32_t blockDim = 40; + bool isColQuant = false; + bool isOutlierIndex = false; + size_t tilingSize = sizeof(struct get_row_col_quant_tiling::RowColQuantCalculator); + get_row_col_quant_tiling::RowColQuantCalculator tilingHost; + uint32_t error = get_row_col_quant_tiling::TilingForRowColQuant(outliersNum, threshold, isColQuant, isOutlierIndex, + rows, cols, blockDim, &tilingHost); + if (error != 0) { + printf("An error occurred.\n"); + } + uint8_t *tilingDevice = nullptr; + aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpyAsync((void *)tilingDevice, tilingSize, &tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE, stream); + ACLRT_LAUNCH_KERNEL(row_col_quant)(tilingHost.usedCoreNumForOutlier, stream, A, rowStats, colStats, outRowNormed, + outColNormed, outliersRowIdx, outliersColIdx, outliersValue, tilingDevice); + CHECK_ACL(aclrtSynchronizeStream(stream)); + aclrtFree(tilingDevice); +} + + +namespace get_row_col_stats_tiling { + const uint32_t PACK_SIZE = 512; // pack unit in cache 512B + const uint32_t ALIGN_SIZE = 32; // align unit in cache 32B + const uint32_t DEFAULT_BUFFER_NUM = 1; + + inline uint32_t RoundUpToN(uint32_t m, uint32_t n) + { + return (m / n) * n; + } + + inline uint32_t GetLength(int32_t dtype, uint32_t &dsize) + { + switch (dtype) { + case 0: + case 1: + dsize = sizeof(int16_t); + return true; + default: + return false; + } + } + + struct RowColStatsTilingCalculator : public RowColStatsTiling + { + public: + bool CalcTiling(uint32_t m, uint32_t k, uint32_t core_num, uint64_t ub_size, uint32_t dtype_len, int32_t dtype, + float th, bool is_outlier_idx) + { + M = m; + K = k; + threshold = th; + is_outlier_index = is_outlier_idx; + + uint64_t element_num = m * k; + // align to 32B by hardware + uint32_t align_num = ALIGN_SIZE / dtype_len; + // align to L2 cacheline 512B (for bandwidth max efficiency) + uint32_t pack_align_num = PACK_SIZE / dtype_len; + buffer_num = DEFAULT_BUFFER_NUM; + + ub_sizes = ub_size; + max_elements_per_ub = GetMaxDatasPerUB(ub_size, dtype_len, 0, 0, buffer_num); + uint64_t align_elements_per_ub = (max_elements_per_ub / pack_align_num) * pack_align_num; + + if (element_num <= align_elements_per_ub) { + used_cores = 1; + } else if (element_num >= align_elements_per_ub * core_num) { + used_cores = core_num; + } else { + used_cores = (element_num + align_elements_per_ub - 1) / align_elements_per_ub; + } + + if (K <= 4096 /*align_elements_per_ub / 4*/) { + 
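+                // K is small: partition the matrix by rows across cores, each core keeping full-width rows (TilingForRow).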
if (!TilingForRow(core_num, align_num)) { + printf("CalcTiling failed for TilingForRow \n"); + return false; + } + } else if (M <= 4096) { + if (!TilingForCol(core_num, align_num)) { + printf("CalcTiling failed for TilingForCol \n"); + return false; + } + } else { + if ((used_cores != core_num) || (!TilingForBlock(core_num, align_num))) { + printf("CalcTiling failed for TilingForBlock \n"); + return false; + } + } + return true; + } + + bool TilingForRow(uint32_t core_num, uint32_t align_num) { + uint32_t core_k = K; + + uint32_t align_lines = align_num; + uint64_t min_core_lines = (M / (used_cores * align_lines)) * align_lines; + std::fill(core_rows, core_rows + core_num, min_core_lines); + std::fill(core_cols, core_cols + core_num, core_k); + uint64_t left_lines = M - min_core_lines * used_cores; + align_lines = (min_core_lines == 0) ? 1 : align_lines; + uint32_t index = 0; + for (uint64_t len = align_lines; len <= left_lines; len += align_lines) { + core_rows[index % used_cores] += align_lines; + index++; + } + core_rows[used_cores - 1] += M % align_lines; + + uint64_t sum_rows = 0; + for (uint32_t i = 0; i < used_cores; i++) { + start_offs[i] = sum_rows * K; + sum_rows += core_rows[i]; + } + return true; + } + + bool TilingForCol(uint32_t core_num, uint32_t align_num) { + uint32_t core_m = M; + + uint32_t align_lines = align_num; + uint64_t min_core_lines = (K / (used_cores * align_lines)) * align_lines; + std::fill(core_rows, core_rows + core_num, core_m); + std::fill(core_cols, core_cols + core_num, min_core_lines); + uint64_t left_lines = K - min_core_lines * used_cores; + uint32_t index = 0; + for (uint64_t len = align_lines; len <= left_lines; len += align_lines) { + core_cols[index % used_cores] += align_lines; + index++; + } + core_cols[used_cores - 1] += K % align_lines; + + uint64_t sum_cols = 0; + for (uint32_t i = 0; i < used_cores; i++) { + start_offs[i] = sum_cols; + sum_cols += core_cols[i]; + } + return true; + } + + bool TilingForBlock(uint32_t core_num, uint32_t align_num) { + uint32_t rcore_num = 4; + uint32_t ccore_num = used_cores / rcore_num; + + uint32_t align_lines = align_num; + uint64_t min_core_rows = (M / (rcore_num * align_lines)) * align_lines; + uint64_t min_core_cols = (K / (ccore_num * align_lines)) * align_lines; + std::fill(core_rows, core_rows + core_num, min_core_rows); + std::fill(core_cols, core_cols + core_num, min_core_cols); + uint64_t left_rows = M - min_core_rows * rcore_num; + uint32_t index = 0; + for (uint64_t len = align_lines; len <= left_rows; len += align_lines) { + for (uint32_t i = 0; i < ccore_num; i++) { + core_rows[(index % rcore_num) * ccore_num + i] += align_lines; + } + index++; + } + for (uint32_t i = 0; i < ccore_num; i++) { + core_rows[(rcore_num - 1) * ccore_num + i] += M % align_lines; + } + + uint64_t left_cols = K - min_core_cols * ccore_num; + index = 0; + for (uint64_t len = align_lines; len <= left_cols; len += align_lines) { + for (uint32_t i = 0; i < rcore_num; i++) { + core_cols[index % ccore_num + i * ccore_num] += align_lines; + } + index++; + } + for (uint32_t i = 0; i < rcore_num; i++) { + core_cols[ccore_num - 1 + i * ccore_num] += K % align_lines; + } + + uint64_t sum_row = 0; + uint64_t sum_col = 0; + for (uint32_t i = 0; i < rcore_num; i++) { + for (uint32_t j = 0; j < ccore_num; j++) { + start_offs[i * ccore_num + j] = sum_row * K + sum_col; + sum_col += core_cols[j]; + } + sum_col = 0; + sum_row += core_rows[i * ccore_num]; + } + return true; + } + + uint64_t GetMaxDatasPerUB(uint64_t ub_size, uint32_t 
dtype_len, uint32_t tile_lines, uint32_t align_k, + uint32_t buffer_num) { + // result is the square of the positive root of a*x^2 + b*x + c = 0 (c absorbs -ub_size), i.e. the per-UB element budget. + float a = ((float) 2 * (buffer_num + 1) + (float) 1 / 8); + float b = (float) 8 * buffer_num; + float c = (float) buffer_num * 320 + 320 - ub_size; + float discriminant = b * b - 4 * a * c; + float result = (2 * b * b - 4 * a * c - 2 * b * sqrt(discriminant)) / (4 * a * a); + return static_cast<uint64_t>(std::floor(result)); + } + + inline uint32_t RoundUp(uint32_t a, uint32_t b) + { + return (a + b - 1) / b; + } + + }; + + uint32_t Tiling4RowColStats(int rows, int cols, uint8_t shapeSize, uint32_t core_num, int32_t dtype, + float threshold, bool is_col_quant, bool is_outlier_index, + get_row_col_stats_tiling::RowColStatsTilingCalculator *tiling_calc) + { + uint64_t ub_size = 192 * 1024; + if (core_num <= 0 || core_num > MAX_CORE_NUMBER || ub_size <= UB_RESERVED_BUFF) { + printf("Compile info is invalid, coreNum:%u, ubSize:%lu\n", core_num, ub_size); + return 1; + } + + uint32_t dtype_len = 0; + if (!GetLength(dtype, dtype_len)) { + printf("Unsupported input data type %d\n", dtype); + return 1; + } + int32_t dim = shapeSize; + if (dim > 3 || dim < 2) { + printf("Unsupported input data shape dim %d\n", dim); + return 1; + } + int32_t M = rows; + int32_t K = cols; + + if (!tiling_calc->CalcTiling(M, K, core_num, ub_size - UB_RESERVED_BUFF, dtype_len, dtype, threshold, is_outlier_index)) { + return 1; + } + return 0; + } +} + + +void rowColStats(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outliersNum, float threshold, int rows, int cols, void *stream) { + uint32_t blockDim = 40; + bool is_col_quant = false; + bool is_outlier_index = false; + uint32_t dtype = 1; + uint8_t shapeSize = 2; + size_t tilingSize = sizeof(struct get_row_col_stats_tiling::RowColStatsTilingCalculator); + get_row_col_stats_tiling::RowColStatsTilingCalculator *tilingHost; + tilingHost = (struct get_row_col_stats_tiling::RowColStatsTilingCalculator *)malloc(tilingSize); + uint32_t error = get_row_col_stats_tiling::Tiling4RowColStats(rows, cols, shapeSize, blockDim, dtype, threshold, is_col_quant, is_outlier_index, tilingHost); + if (error != 0) { + printf("An error occurred.\n"); + } + uint8_t *tilingDevice = nullptr; + aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpyAsync((void *)tilingDevice, tilingSize, tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE, stream); + ACLRT_LAUNCH_KERNEL(row_col_stats)(tilingHost->used_cores, stream, A, rowStats, colStats, outliersNum, tilingDevice); + CHECK_ACL(aclrtSynchronizeStream(stream)); + aclrtFree(tilingDevice); + free(tilingHost); +} + } diff --git a/csrc/npu_ops.h b/csrc/npu_ops.h index d7a26cd34..a4418d1d4 100644 --- a/csrc/npu_ops.h +++ b/csrc/npu_ops.h @@ -10,6 +10,88 @@ } \ } while (0); + +// align num to multiples of rnd, round up +#define ALIGNUP(num, rnd) (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd))) +// align num to multiples of rnd, round down +#define ALIGNDOWN(num, rnd) ((((rnd) == 0) || ((num) < (rnd))) ? 0 : ((num) / (rnd) * (rnd))) +// div and Round Up +#define DIVCEIL(num, div) (((div) == 0) ?
0 : (((num) + (div)-1) / (div))) + +const uint32_t UB_RESERVED_BUFF = 8 * 1024; +const uint32_t MAX_CORE_NUMBER = 64; +const uint32_t L2_CACHE_LINE_SIZE = 512; +const uint32_t UB_MIN_BLOCK_SIZE = 32; +const uint32_t MAX_BLOCK_COUNT = 4095; +const uint32_t MAX_BLOCK_LEN = 65535 * 32; +const uint32_t MAX_UINT32 = 4294967295; +const uint16_t DISCONTINE_COPY_MAX_BLOCKCNT = 4095; +const uint16_t DISCONTINE_COPY_MAX_BLOCKLEN = 65535; +const uint16_t DISCONTINE_COPY_MAX_STRIDE = 65535; + +// row_col_quant +struct OutlierTilingParam { + uint64_t colLen = 0; + uint32_t loopNum = 0; + uint32_t tileCol = 0; + uint16_t isTailExist = 0; + uint32_t tailCol = 0; +}; + +struct RowColQuantTilingData { + uint32_t usedCoreNum = 0; // number of vector core. Don't move, must be in the first + uint32_t is32BAligned = 1; + uint32_t isDoubleBuffer = 0; + + uint64_t rowLen = 1; // row length for split vector, Unit:element + uint64_t colLen = 1; // column length for split vector, Unit:element + uint32_t baseRowLen = 2; // for one tile in one core, Unit:element + uint32_t baseColLen = 16; // for one tile in one core, Unit:element + + float threshold = 0.0f; + uint32_t outliersNum = 0; + uint32_t isColQuant = 0; + uint32_t isOutlierIndex = 0; + + uint32_t usedCoreNumForOutlier = 0; + uint32_t baseCoreNumForOutlier = 0; + OutlierTilingParam baseCoreParam; + OutlierTilingParam tailCoreParam; +}; + +// row_col_stats +struct RowColStatsTiling { + uint64_t start_offs[MAX_CORE_NUMBER] = {0}; + uint32_t core_rows[MAX_CORE_NUMBER] = {0}; + uint32_t core_cols[MAX_CORE_NUMBER] = {0}; + uint32_t max_elements_per_ub = 0; + uint32_t used_cores = 0; // number of vector core. Don't move, must be in the first + uint32_t buffer_num = 1; + uint32_t M = 0; + uint32_t K = 0; + uint32_t ub_sizes = 0; + float threshold = 0; + bool is_outlier_index = true; + bool use_gather_mask = true; +}; + +struct RowColStatsTilingKernel { + uint32_t tile_lines = 0; + uint32_t tail_tile_lines = 0; + uint32_t tile_num = 0; + uint32_t last_tile_idx = 0; + uint32_t M = 0; + uint32_t K = 0; + uint32_t core_k = 0; + uint32_t core_m = 0; + uint32_t align_k = 0; + uint32_t align_m = 0; + uint32_t align_K = 0; + uint32_t ub_sizes = 0; + float threshold = 0; + bool is_outlier_index = true; + bool use_gather_mask = true; + uint64_t start_off = 0; +}; struct BlockwiseNf4TilingData { uint32_t coreNum; @@ -24,5 +106,11 @@ extern "C" { void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode); +void rowColQuant(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outRowNormed, uint8_t *outColNormed, + uint8_t *outliersRowIdx, uint8_t *outliersColIdx, uint8_t *outliersValue, uint32_t outliersNum, + float threshold, int rows, int cols, void* stream); + +void rowColStats(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outliersNum, float threshold, int rows, int cols, void *stream); + } #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 489151a87..e8ffe5502 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -584,6 +584,16 @@ int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, void cdequantize_blockwise_fp16_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream) { dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 2); } + + void cget_col_row_stats(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outliersNum, float threshold, int rows, int 
cols, void *stream) { + rowColStats(A, rowStats, colStats, outliersNum, threshold, rows, cols, stream); + } + + void cdouble_rowcol_quant(uint8_t *A, uint8_t *rowStats, uint8_t *colStats, uint8_t *outRowNormed, uint8_t *outColNormed, + uint8_t *outliersRowIdx, uint8_t *outliersColIdx, uint8_t *outliersValue, uint32_t outliersNum, + float threshold, int rows, int cols, void *stream) { + rowColQuant(A, rowStats, colStats, outRowNormed, outColNormed, outliersRowIdx, outliersColIdx, outliersValue, outliersNum, threshold, rows, cols, stream); + } #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
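For context, a minimal host-side sketch of how the new rowColStats / rowColQuant entry points could be driven directly from C++ on an Ascend device. Only the function signatures come from npu_ops.h above; the shapes, device id, threshold, and buffer sizes are illustrative assumptions, and input upload plus error handling are omitted for brevity.

#include <cstdint>
#include "acl/acl.h"
#include "npu_ops.h"

int main() {
    const int rows = 128, cols = 4096;                 // assumed problem size
    aclInit(nullptr);
    aclrtSetDevice(0);                                 // assumed device id
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    // Helper: raw device allocation returning a byte pointer.
    auto npuAlloc = [](size_t bytes) {
        void *p = nullptr;
        aclrtMalloc(&p, bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
        return static_cast<uint8_t *>(p);
    };

    uint8_t *A          = npuAlloc(rows * cols * 2);   // fp16 input
    uint8_t *rowStats   = npuAlloc(rows * 4);          // float32 row abs-max
    uint8_t *colStats   = npuAlloc(cols * 4);          // float32 col abs-max
    uint8_t *outlierCnt = npuAlloc(4);                 // int32 outlier count
    uint8_t *outRow     = npuAlloc(rows * cols);       // int8 row-wise quantized
    uint8_t *outCol     = npuAlloc(rows * cols);       // int8 col-wise quantized
    uint8_t *outRowIdx  = npuAlloc(rows * 4);          // int32 outlier row indices
    uint8_t *outColIdx  = npuAlloc(40 * cols * 4);     // int32 outlier col indices
    uint8_t *outVal     = npuAlloc(4);                 // placeholder outlier values

    const float threshold = 6.0f;                      // assumed outlier threshold
    rowColStats(A, rowStats, colStats, outlierCnt, threshold, rows, cols, stream);
    rowColQuant(A, rowStats, colStats, outRow, outCol, outRowIdx, outColIdx, outVal,
                /*outliersNum=*/0, threshold, rows, cols, stream);

    // ...copy results back with aclrtMemcpy, then release each buffer with aclrtFree...
    aclrtDestroyStream(stream);
    aclrtResetDevice(0);
    aclFinalize();
    return 0;
}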