From bb8b656ce9157ec880d623dc1c7120435935ce1a Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 25 Mar 2024 17:25:19 +0800 Subject: [PATCH] =?UTF-8?q?perf(kernel):=20=E4=B8=BA=20Transpose=20?= =?UTF-8?q?=E8=A1=A5=E5=85=85=20reform?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- /* Review notes (free-form text after the three-dash line; git am ignores it). Subject in English: "perf(kernel): add reform for Transpose". What the diff below shows: (1) transpose_info.h declares TransposeInfo::reform(dim_t maxblockSize) const noexcept and TransposeInfo::reformAssign(dim_t maxblockSize) noexcept. (2) transpose_info.cc defines them: reform copies *this, calls reformAssign on the copy and returns the copy; reformAssign computes g = std::gcd(blockSize, maxblockSize), returns without touching anything when g == blockSize, and otherwise sets times = blockSize / g, multiplies blockCount by times, stores g into blockSize, multiplies strideO and strideI of every entry of dims by times, then grows dims by one entry assigned {1, 1}. (3) cuda_kernel.cc changes the TransposeCuda constructor to initialise info with info_.reform(16) instead of std::move(info_). NOTE(review): std::gcd lives in the numeric header and no hunk adds that include, so confirm transpose_info.cc already has it. NOTE(review): the constructor now copies info_ inside reform rather than moving it; moving info_ into the member and then calling reformAssign(16) would presumably avoid the copy, verify before changing. NOTE(review): the unit of the literal 16 and the field order behind {1, 1} are not visible in this patch; confirm against the declarations of TransposeInfo and of its dims element type. */ .../kernel/attributes/transpose_info.h | 3 +++ src/04kernel/src/attributes/transpose_info.cc | 20 +++++++++++++++++++ .../src/kernels/transpose/cuda_kernel.cc | 2 +- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/04kernel/include/kernel/attributes/transpose_info.h b/src/04kernel/include/kernel/attributes/transpose_info.h index 00c49315..559ddb82 100644 --- a/src/04kernel/include/kernel/attributes/transpose_info.h +++ b/src/04kernel/include/kernel/attributes/transpose_info.h @@ -23,6 +23,9 @@ namespace refactor::kernel { TransposeInfo(DataType, Shape const &, Permutation const &); dim_t locate(dim_t) const noexcept; + + TransposeInfo reform(dim_t maxblockSize) const noexcept; + void reformAssign(dim_t maxblockSize) noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/transpose_info.cc b/src/04kernel/src/attributes/transpose_info.cc index 9ae385a9..d39be305 100644 --- a/src/04kernel/src/attributes/transpose_info.cc +++ b/src/04kernel/src/attributes/transpose_info.cc @@ -118,4 +118,24 @@ namespace refactor::kernel { return ans; } + TransposeInfo TransposeInfo::reform(dim_t maxblockSize) const noexcept { + auto ans = *this; + ans.reformAssign(maxblockSize); + return ans; + } + + void TransposeInfo::reformAssign(dim_t maxblockSize) noexcept { + auto blockSize_ = std::gcd(blockSize, maxblockSize); + if (blockSize_ == blockSize) { return; } + auto times = blockSize / blockSize_; + blockCount *= times; + blockSize = blockSize_; + for (auto &d : dims) { + d.strideO *= times; + d.strideI *= times; + } + dims.resize(dims.size() + 1); + dims.back() = {1, 1}; + } + }// namespace 
refactor::kernel diff --git a/src/04kernel/src/kernels/transpose/cuda_kernel.cc b/src/04kernel/src/kernels/transpose/cuda_kernel.cc index aed8d7b6..a557dc82 100644 --- a/src/04kernel/src/kernels/transpose/cuda_kernel.cc +++ b/src/04kernel/src/kernels/transpose/cuda_kernel.cc @@ -5,7 +5,7 @@ namespace refactor::kernel { using Info = TransposeInfo; K::TransposeCuda(Info info_) noexcept - : Kernel(), info(std::move(info_)) {} + : Kernel(), info(info_.reform(16)) {} auto K::build(Info info) noexcept -> KernelBox { #ifndef USE_CUDA