Performance Matmul: batched matmul #55

Merged
merged 1 commit into from Dec 4, 2023
6 changes: 4 additions & 2 deletions src/04kernel/include/kernel/attributes/matmul_info.h
@@ -3,6 +3,7 @@

#include "kernel/attributes/broadcaster.h"
#include "kernel/attributes/expand_info.h"
#include <variant>

namespace refactor::kernel {

@@ -11,9 +12,10 @@ namespace refactor::kernel {
float alpha, beta;
bool transA, transB;
size_t m, k, n;
// Expand operation info for the bias
std::optional<ExpandInfo> biasExpand;
// A 2-directional broadcaster that deals with dimensions before the last 2 dimensions
Broadcaster broadcaster;
// A constant batch or a 2-directional broadcaster that deals with dimensions before the last 2 dimensions
std::variant<Broadcaster, size_t> broadcasterOrBatch;

MatMulInfo(Tensor const &, Tensor const &,
std::optional<std::reference_wrapper<Tensor const>>,
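The new `broadcasterOrBatch` field lets the kernels take a fast path when the batch dimensions of A and B are identical, and fall back to full broadcasting otherwise. A minimal, self-contained sketch of that dispatch pattern — `FakeBroadcaster` below is a hypothetical stand-in that only mimics the `outputsCount`/`locate` interface of the real `Broadcaster`:

```cpp
#include <cstdio>
#include <variant>

// Hypothetical stand-in for the real Broadcaster; only the parts used here.
struct FakeBroadcaster {
    size_t outputsCount;
    void locate(size_t i, unsigned offset[2]) const {
        // Toy rule: A is broadcast along the batch axis, B is not.
        offset[0] = 0;
        offset[1] = static_cast<unsigned>(i);
    }
};

void runBatched(std::variant<FakeBroadcaster, size_t> const &bob) {
    if (std::holds_alternative<size_t>(bob)) {
        // Fast path: a single strided loop (or one strided-batched GEMM).
        auto batch = std::get<size_t>(bob);
        for (size_t i = 0; i < batch; ++i)
            std::printf("gemm A[%zu] * B[%zu] -> Y[%zu]\n", i, i, i);
    } else {
        // General path: the broadcaster maps each output matrix to its inputs.
        auto const &bc = std::get<FakeBroadcaster>(bob);
        unsigned offset[2];
        for (size_t i = 0; i < bc.outputsCount; ++i) {
            bc.locate(i, offset);
            std::printf("gemm A[%u] * B[%u] -> Y[%zu]\n", offset[0], offset[1], i);
        }
    }
}

int main() {
    runBatched(size_t{3});         // identical batch dims -> plain batch count
    runBatched(FakeBroadcaster{2});// mismatched batch dims -> broadcaster
}
```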
12 changes: 8 additions & 4 deletions src/04kernel/src/attributes/matmul_info.cc
@@ -26,6 +26,13 @@ namespace refactor::kernel {
slice(output.data(), output.size()));
}

std::variant<Broadcaster, size_t> buildBroadcasterOrBatch(slice_t<dim_t> dimA, slice_t<dim_t> dimB) {
if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies<size_t>());
}
return Broadcaster({dimA, dimB});
}

MatMulInfo::MatMulInfo(
Tensor const &a, Tensor const &b,
std::optional<std::reference_wrapper<Tensor const>> c,
@@ -37,10 +44,7 @@ namespace refactor::kernel {
k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]),
n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]),
biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt),
broadcaster({
slice(a.shape.data(), a.shape.size() - 2),
slice(b.shape.data(), b.shape.size() - 2),
}) {
broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) {
auto kB = transB ? b.shape.rbegin()[0] : b.shape.rbegin()[1];
ASSERT(k == kB, "MatMul: input shape not matched.");
}
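To make `buildBroadcasterOrBatch` concrete: when the leading (batch) dimensions of A and B match element-for-element, the product of those dimensions is stored as a plain batch count; otherwise a `Broadcaster` is constructed. A toy, self-contained illustration of just that rule, assuming plain `std::vector<unsigned>` in place of the project's `slice_t<dim_t>`:

```cpp
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Returns the flat batch count when the batch dims match exactly; returns 0
// as a stand-in for "build a Broadcaster" (the real code returns a std::variant).
size_t batchOrZero(std::vector<unsigned> const &dimA, std::vector<unsigned> const &dimB) {
    if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
        return std::accumulate(dimA.begin(), dimA.end(), size_t{1}, std::multiplies<>());
    }
    return 0;// would be Broadcaster({dimA, dimB}) in the real code
}

int main() {
    // A: [2, 3, m, k] with B: [2, 3, k, n] -> batch dims equal -> batch = 2 * 3 = 6.
    std::printf("%zu\n", batchOrZero({2, 3}, {2, 3}));
    // A: [2, 3, m, k] with B: [1, 3, k, n] -> mismatch -> broadcaster path (0 here).
    std::printf("%zu\n", batchOrZero({2, 3}, {1, 3}));
}
```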
67 changes: 44 additions & 23 deletions src/04kernel/src/kernels/mat_mul/cpu_kernel.cc
@@ -52,29 +52,50 @@ namespace refactor::kernel {
}
}

#define CASE(T) \
case DT::T: { \
using T_ = primitive<DT::T>::type; \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
broadcaster = info.broadcaster, \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
dim_t offset[2]; \
for (size_t i = 0; i < broadcaster.outputsCount; i++) { \
broadcaster.locate(i, offset); \
matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \
} \
}; \
#define CASE(T) \
case DT::T: { \
using T_ = primitive<DT::T>::type; \
if (std::holds_alternative<Broadcaster>(info.broadcasterOrBatch)) { \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch), \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
dim_t offset[2]; \
for (size_t i = 0; i < broadcaster.outputsCount; i++) { \
broadcaster.locate(i, offset); \
matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \
} \
}; \
} else { \
return [alpha = static_cast<T_>(info.alpha), \
beta = static_cast<T_>(info.biasExpand ? info.beta : 0.0f), \
batch = std::get<size_t>(info.broadcasterOrBatch), \
md, \
stepY = info.m * info.n, \
stepA = info.m * info.k, \
stepB = info.k * info.n, \
biasEx = info.biasExpand \
? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \
: std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \
auto A = reinterpret_cast<T_ const *>(inputs[0]); \
auto B = reinterpret_cast<T_ const *>(inputs[1]); \
auto Y = reinterpret_cast<T_ *>(outputs[0]); \
for (size_t i = 0; i < batch; i++) { \
matrixMultiply(A + stepA * i, B + stepB * i, Y + stepY * i, alpha, beta, md); \
} \
}; \
} \
}

auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
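In the constant-batch branch of the CPU kernel, A, B, and Y simply advance by the fixed strides stepA, stepB, and stepY per batch, with no offset lookup. A naive, self-contained version of that loop — the triple-loop `matmulNaive` below is an illustrative stand-in, not the project's `matrixMultiply`:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Naive single-matrix GEMM: Y = alpha * A(m x k) * B(k x n), all row-major.
static void matmulNaive(float const *A, float const *B, float *Y,
                        size_t m, size_t k, size_t n, float alpha) {
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j) {
            float acc = 0;
            for (size_t p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
            Y[i * n + j] = alpha * acc;
        }
}

int main() {
    size_t batch = 2, m = 2, k = 3, n = 2;
    size_t stepA = m * k, stepB = k * n, stepY = m * n;
    std::vector<float> A(batch * stepA, 1.f), B(batch * stepB, 2.f), Y(batch * stepY);
    // Constant-batch fast path: every operand advances by a fixed stride.
    for (size_t i = 0; i < batch; ++i)
        matmulNaive(A.data() + stepA * i, B.data() + stepB * i,
                    Y.data() + stepY * i, m, k, n, 1.f);
    std::printf("Y[0] = %g (expect 6)\n", Y[0]);// 1 * 2 summed over k = 3
}
```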
89 changes: 60 additions & 29 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
@@ -11,36 +11,67 @@ namespace refactor::kernel {
static auto lowerTyped(cudaDataType_t cudaDataType,
MatMulInfo info,
Resources &res) noexcept -> RoutineWorkspace {
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
broadcaster = info.broadcaster](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }
if (std::holds_alternative<size_t>(info.broadcasterOrBatch)) {
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
batch = std::get<size_t>(info.broadcasterOrBatch)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
// Call expand kernel to broadcast bias if bias is used
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
uint32_t offset[2];
for (auto i : range0_(broadcaster.outputsCount)) {
broadcaster.locate(i, offset);
auto stat = cublasGemmEx(
handle, tB, tA, n, m, k, &alpha, b + strideB * offset[1],
cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, &beta, y + strideY * i,
cudaDataType, n, cudaDataType, CUBLAS_GEMM_DEFAULT);
}
};
auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
auto stat = cublasGemmStridedBatchedEx(
handle, tB, tA, n, m, k, &alpha, b,
cudaDataType, ldb, strideB, a, cudaDataType, lda, strideA,
&beta, y, cudaDataType, n, m * n, batch, cudaDataType,
CUBLAS_GEMM_DEFAULT);
};
} else {// if using the broadcaster
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasExpand ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx = info.biasExpand
? std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine)
: std::nullopt,
broadcaster = std::get<Broadcaster>(info.broadcasterOrBatch)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); }

auto handle = res.fetchOrStore<CublasContext>()->handle;
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);
uint32_t offset[2];
for (auto i : range0_(broadcaster.outputsCount)) {
broadcaster.locate(i, offset);
auto stat = cublasGemmEx(
handle, tB, tA, n, m, k, &alpha, b + strideB * offset[1],
cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, &beta, y + strideY * i,
cudaDataType, n, cudaDataType, CUBLAS_GEMM_DEFAULT);
}
};
}
}

auto MatMulCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {
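On the CUDA side, the constant-batch branch replaces the per-matrix loop of cublasGemmEx calls with a single cublasGemmStridedBatchedEx launch, which drops the host-side loop and lets cuBLAS schedule all GEMMs in one call. A minimal standalone sketch of the same idea, using the simpler fixed-type cublasSgemmStridedBatched rather than the Ex variant from the kernel above, with the same operand-swap trick for row-major data:

```cpp
// build (assumed): nvcc strided_batched_demo.cu -lcublas
#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
    // 2 batches of row-major A(m x k) * B(k x n) = Y(m x n).
    int batch = 2, m = 2, k = 3, n = 2;
    long long strideA = m * k, strideB = k * n, strideY = m * n;
    std::vector<float> hA(batch * strideA, 1.f), hB(batch * strideB, 2.f), hY(batch * strideY);

    float *dA, *dB, *dY;
    cudaMalloc(&dA, hA.size() * sizeof(float));
    cudaMalloc(&dB, hB.size() * sizeof(float));
    cudaMalloc(&dY, hY.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    float alpha = 1.f, beta = 0.f;
    // Row-major trick (as in the kernel above): pass B first so column-major
    // cuBLAS computes B^T * A^T = (A * B)^T, i.e. Y in row-major layout.
    cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              n, m, k, &alpha,
                              dB, n, strideB,
                              dA, k, strideA,
                              &beta, dY, n, strideY, batch);

    cudaMemcpy(hY.data(), dY, hY.size() * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("Y[0] = %g (expect 6)\n", hY[0]);// 1 * 2 summed over k = 3

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dY);
}
```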