perf(kernel): batched matmul
PanZezhong1725 authored and YdrMaster committed Dec 4, 2023
1 parent 9f2cc3e commit 0c1c117
Showing 4 changed files with 116 additions and 58 deletions.
6 changes: 4 additions & 2 deletions src/04kernel/include/kernel/attributes/matmul_info.h
@@ -3,6 +3,7 @@

#include "kernel/attributes/broadcaster.h"
#include "kernel/attributes/expand_info.h"
#include <variant>

namespace refactor::kernel {

@@ -11,9 +12,10 @@ namespace refactor::kernel {
float alpha, beta;
bool transA, transB;
size_t m, k, n;
// Expand operation info for bias
std::optional<ExpandInfo> biasExpand;
// A 2-directional broadcaster that deals with dimensions before the last 2 dimensions
Broadcaster broadcaster;
// A constant batch or a 2-directional broadcaster that deals with dimensions before the last 2 dimensions
std::variant<Broadcaster, size_t> broadcasterOrBatch;

MatMulInfo(Tensor const &, Tensor const &,
std::optional<std::reference_wrapper<Tensor const>>,
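The header change swaps the unconditional `Broadcaster` member for `std::variant<Broadcaster, size_t>`: when the leading (pre-matrix) dimensions of A and B are identical, a plain batch count is stored and no per-output index translation is needed. Below is a minimal sketch of how a kernel could dispatch on such a member; the `Broadcaster` stand-in and the `outputBatches` helper are illustrative only, and the kernels in this commit use `std::holds_alternative`/`std::get` rather than `std::visit`.

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <variant>

// Stand-in with just the pieces this sketch needs (not the repository's Broadcaster).
struct Broadcaster {
    size_t outputsCount;
    void locate(size_t i, uint32_t offset[2]) const;// maps output batch i to input batch indices
};

using BroadcasterOrBatch = std::variant<Broadcaster, size_t>;

// How many output matrices the batched matmul produces, whichever alternative is held.
size_t outputBatches(BroadcasterOrBatch const &v) {
    return std::visit(
        [](auto const &x) -> size_t {
            if constexpr (std::is_same_v<std::decay_t<decltype(x)>, size_t>) {
                return x;              // constant batch: A, B and Y share one flat index
            } else {
                return x.outputsCount; // broadcast case: the broadcaster knows the count
            }
        },
        v);
}
```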
12 changes: 8 additions & 4 deletions src/04kernel/src/attributes/matmul_info.cc
@@ -26,6 +26,13 @@
slice(output.data(), output.size()));
}

std::variant<Broadcaster, size_t> buildBroadcasterOrBatch(slice_t<dim_t> dimA, slice_t<dim_t> dimB) {
if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies<size_t>());
}
return Broadcaster({dimA, dimB});
}

MatMulInfo::MatMulInfo(
Tensor const &a, Tensor const &b,
std::optional<std::reference_wrapper<Tensor const>> c,
@@ -37,10 +44,7 @@
k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]),
n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]),
biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt),
broadcaster({
slice(a.shape.data(), a.shape.size() - 2),
slice(b.shape.data(), b.shape.size() - 2),
}) {
broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) {
auto kB = transB ? b.shape.rbegin()[0] : b.shape.rbegin()[1];
ASSERT(k == kB, "MatMul: input shape not matched.");
}
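`buildBroadcasterOrBatch` collapses the common case: if the leading shapes of A and B match element for element, their product is the batch count; otherwise a full `Broadcaster` is built. A standalone sketch of the same rule, with `std::vector<size_t>` standing in for the repository's `slice_t<dim_t>` and an empty placeholder `Broadcaster`:

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <numeric>
#include <variant>
#include <vector>

struct Broadcaster {};// placeholder; the real one precomputes index-translation strides

std::variant<Broadcaster, size_t>
buildBroadcasterOrBatch(std::vector<size_t> const &dimA, std::vector<size_t> const &dimB) {
    if (dimA == dimB) {// identical leading shapes: a flat batch count is enough
        return std::accumulate(dimA.begin(), dimA.end(), size_t{1}, std::multiplies<>{});
    }
    return Broadcaster{};// shapes differ: fall back to full broadcasting
}

int main() {
    // [2, 3, m, k] x [2, 3, k, n] -> constant batch of 6
    auto same = buildBroadcasterOrBatch({2, 3}, {2, 3});
    std::printf("batch = %zu\n", std::get<size_t>(same));
    // [2, 1, m, k] x [1, 3, k, n] -> needs a Broadcaster
    auto mixed = buildBroadcasterOrBatch({2, 1}, {1, 3});
    std::printf("uses broadcaster = %d\n", std::holds_alternative<Broadcaster>(mixed));
}
```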
67 changes: 44 additions & 23 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.cc
@@ -52,29 +52,50 @@ namespace refactor::kernel {
}
}

#define CASE(T) \
case DT::T: { \
using T_ = primitive<DT::T>::type; \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
broadcaster = info.broadcaster, \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
dim_t offset[2]; \
for (size_t i = 0; i < broadcaster.outputsCount; i++) { \
broadcaster.locate(i, offset); \
matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \
} \
}; \
#define CASE(T) \
case DT::T: { \
using T_ = primitive<DT::T>::type; \
if (std::holds_alternative<Broadcaster>(info.broadcasterOrBatch)) { \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch), \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
dim_t offset[2]; \
for (size_t i = 0; i < broadcaster.outputsCount; i++) { \
broadcaster.locate(i, offset); \
matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \
} \
}; \
} else { \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
batch = std::get<size_t>(info.broadcasterOrBatch), \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
for (size_t i = 0; i < batch; i++) { \
matrixMultiply(A + stepA * i, B + stepB * i, Y + stepY * i, alpha, beta, md); \
} \
}; \
} \
}

auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
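The rewritten `CASE` macro keeps the old broadcaster loop and adds a fast path for the constant-batch alternative, where all three tensors advance with the same flat index and `broadcaster.locate` disappears from the inner loop. A self-contained, de-macro'd illustration of that fast path follows; the `naiveGemm`/`batchedMatMul` names and the float-only scalar loop are illustrative, not the repository's `matrixMultiply`, and the transposition flags carried by `md` in the real kernel are omitted.

```cpp
#include <cstddef>

// Plain reference GEMM: Y = alpha * (A . B) + beta * Y, row-major, one batch element.
static void naiveGemm(float const *A, float const *B, float *Y,
                      size_t m, size_t k, size_t n, float alpha, float beta) {
    for (size_t r = 0; r < m; ++r)
        for (size_t c = 0; c < n; ++c) {
            float acc = 0;
            for (size_t t = 0; t < k; ++t) acc += A[r * k + t] * B[t * n + c];
            Y[r * n + c] = alpha * acc + beta * Y[r * n + c];
        }
}

// Constant-batch fast path: no broadcaster.locate(), just one flat index for A, B and Y.
void batchedMatMul(float const *A, float const *B, float *Y,
                   size_t batch, size_t m, size_t k, size_t n,
                   float alpha, float beta) {
    size_t const stepA = m * k, stepB = k * n, stepY = m * n;
    for (size_t i = 0; i < batch; ++i)
        naiveGemm(A + stepA * i, B + stepB * i, Y + stepY * i, m, k, n, alpha, beta);
}
```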
89 changes: 60 additions & 29 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
@@ -11,36 +11,67 @@ namespace refactor::kernel {
static auto lowerTyped(cudaDataType_t cudaDataType,
MatMulInfo info,
Resources &res) noexcept -> RoutineWorkspace {
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
broadcaster = info.broadcaster](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }
if (std::holds_alternative<size_t>(info.broadcasterOrBatch)) {
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
batch = std::get<size_t>(info.broadcasterOrBatch)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
// Call expand kernel to broadcast bias if bias is used
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
uint32_t offset[2];
for (auto i : range0_(broadcaster.outputsCount)) {
broadcaster.locate(i, offset);
auto stat = cublasGemmEx(
handle, tB, tA, n, m, k, &alpha, b + strideB * offset[1],
cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, &beta, y + strideY * i,
cudaDataType, n, cudaDataType, CUBLAS_GEMM_DEFAULT);
}
};
auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
auto stat = cublasGemmStridedBatchedEx(
handle, tB, tA, n, m, k, &alpha, b,
cudaDataType, ldb, strideB, a, cudaDataType, lda, strideA,
&beta, y, cudaDataType, n, m * n, batch, cudaDataType,
CUBLAS_GEMM_DEFAULT);
};
} else {// if using the broadcaster
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
uint32_t offset[2];
for (auto i : range0_(broadcaster.outputsCount)) {
broadcaster.locate(i, offset);
auto stat = cublasGemmEx(
handle, tB, tA, n, m, k, &alpha, b + strideB * offset[1],
cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, &beta, y + strideY * i,
cudaDataType, n, cudaDataType, CUBLAS_GEMM_DEFAULT);
}
};
}
}

auto MatMulCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {
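The payoff on the GPU side: when the batch is constant, the per-matrix loop of `cublasGemmEx` calls collapses into a single `cublasGemmStridedBatchedEx` call, letting cuBLAS schedule the whole batch at once. Below is a sketch of that call for row-major float inputs without transposition, mirroring the operand swap in the diff (column-major B^T·A^T has the same memory layout as row-major A·B); the function name is illustrative, error checking and handle setup are omitted, and the `CUBLAS_COMPUTE_32F` compute type assumes the CUDA 11+ API.

```cpp
#include <cublas_v2.h>

// Row-major Y = alpha * A . B + beta * Y for `batch` contiguous matrices per operand.
cublasStatus_t rowMajorBatchedGemm(cublasHandle_t handle,
                                   float const *A, float const *B, float *Y,
                                   int m, int n, int k, int batch,
                                   float alpha, float beta) {
    long long const strideA = 1LL * m * k,
                    strideB = 1LL * k * n,
                    strideY = 1LL * m * n;
    return cublasGemmStridedBatchedEx(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,  // no transposes beyond the row-/column-major swap
        n, m, k,                   // sizes of the column-major product B^T (n x k) . A^T (k x m)
        &alpha,
        B, CUDA_R_32F, n, strideB, // B^T: leading dimension n
        A, CUDA_R_32F, k, strideA, // A^T: leading dimension k
        &beta,
        Y, CUDA_R_32F, n, strideY, // Y^T: leading dimension n
        batch,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```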
