Skip to content

Commit

Permalink
feat(kernel): 添加 cast collector
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 7, 2023
1 parent 8116aca commit 0e4a3ce
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 16 deletions.
18 changes: 18 additions & 0 deletions src/04kernel/include/kernel/collectors/cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef KERNEL_CAST_H
#define KERNEL_CAST_H

#include "../collector.h"

namespace refactor::kernel {

/// Kernel collector for the Cast operator.
///
/// Derives from `InfoCollector` and is created for a specific target.
struct CastCollector final : InfoCollector {

    /// Creates a collector for the given target.
    explicit CastCollector(decltype(_target)) noexcept;

    /// Returns the candidate kernels for the given input and output tensors.
    auto filter(TensorRefs inputs, TensorRefs outputs) const
        -> std::vector<KernelBox> final;
};

}// namespace refactor::kernel

#endif// KERNEL_CAST_H
2 changes: 1 addition & 1 deletion src/04kernel/include/kernel/collectors/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace refactor::kernel {

struct ClipCollector final : public InfoCollector {

ClipCollector(decltype(_target)) noexcept;
explicit ClipCollector(decltype(_target)) noexcept;

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
Expand Down
22 changes: 22 additions & 0 deletions src/04kernel/src/collectors/cast.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "kernel/collectors/cast.h"

namespace refactor::kernel {

/// Creates a Cast collector; the target is handed straight to the `InfoCollector` base.
CastCollector::CastCollector(decltype(_target) target) noexcept : InfoCollector(target) {
}

std::vector<KernelBox>
CastCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
break;
case decltype(_target)::Nvidia:
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/reduce/cudnn_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace refactor::kernel {
#endif

auto const &x = inputs_[0].get();
return x.dataType.isCpuNumberic()
return x.dataType.isFloat()
? std::make_unique<K>(x.dataType, reduceType_, std::move(axes_), x.shape)
: nullptr;
}
Expand Down
5 changes: 1 addition & 4 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ namespace refactor::kernel {
: Kernel(), info(std::move(info_)) {}

auto K::build(SoftmaxInfo info) noexcept -> KernelBox {
static const std::unordered_set<decltype(DataType::internal)>
TYPES{DataType::F32, DataType::F64, DataType::FP16, DataType::BF16};

#ifndef USE_CUDA
return nullptr;
#endif

return TYPES.contains(info.type)
return info.type.isFloat()
? std::make_unique<K>(std::move(info))
: nullptr;
}
Expand Down
6 changes: 3 additions & 3 deletions src/05computation/include/computation/operators/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
namespace refactor::computation {

struct Cast final : public Operator {
DataType targetDataType;

constexpr explicit Cast(DataType targetDataType_) noexcept
: Operator(), targetDataType(targetDataType_) {}
constexpr explicit Cast() noexcept : Operator() {}

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const noexcept final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation
Expand Down
16 changes: 14 additions & 2 deletions src/05computation/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace refactor::computation {
std::vector<kernel::Edge> edges(graph.edges.size());

std::unordered_map<count_t, count_t> identities;
std::vector<std::string> noKernel;
for (auto [nodeIdx, inputs, outputs] : graph.topology) {
auto const &[op, name] = graph.nodes[nodeIdx];
nodes[nodeIdx] = {nullptr, name};
Expand All @@ -44,8 +45,19 @@ namespace refactor::computation {
return std::cref(*graph.edges[i].tensor);
});
auto candidates = op->candidateKernels(target)->filter(std::move(inputs_), std::move(outputs_));
ASSERT(!candidates.empty(), "No kernel selected for \"{}\"", name);
nodes[nodeIdx].kernel = std::move(candidates.front());
if (!candidates.empty()) {
nodes[nodeIdx].kernel = std::move(candidates.front());
} else {
noKernel.push_back(name);
}
}
if (!noKernel.empty()) {
std::stringstream ss;
ss << "No kernel selected for ";
for (auto x : noKernel) {
ss << '"' << x << "\" ";
}
RUNTIME_ERROR(ss.str());
}

for (auto i : range0_(edges.size())) {
Expand Down
15 changes: 12 additions & 3 deletions src/05computation/src/operators/cast.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
#include "computation/operators/cast.h"
#include "kernel/collectors/cast.h"

namespace refactor::computation {
using Op = Cast;

size_t Cast::typeId() noexcept {
size_t Op::typeId() noexcept {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
size_t Cast::opTypeId() const noexcept { return typeId(); }
std::string_view Cast::name() const noexcept { return "Cast"; }
size_t Op::opTypeId() const noexcept { return typeId(); }
std::string_view Op::name() const noexcept { return "Cast"; }
/// Creates the kernel collector used to pick a Cast kernel for `target`.
auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
    return std::make_unique<kernel::CastCollector>(target);
}
/// Returns the textual form of this operator, which is always "Cast()".
auto Op::serialize() const noexcept -> std::string {
    return std::string("Cast()");
}

}// namespace refactor::computation
3 changes: 1 addition & 2 deletions src/07onnx/src/operators/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,9 @@ namespace refactor::onnx {
}
return Ok(Tensors{std::move(ans)});
}

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::Cast;
return std::make_unique<Op_>(to);
return std::make_unique<Op_>();
}

}// namespace refactor::onnx
1 change: 1 addition & 0 deletions src/09python_ffi/src/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace refactor::python_ffi {
CASE(U32);
CASE(U64);
CASE(Bool);
if (dt.is(py::dtype(23))) { return DataType::FP16; }

#undef CASE
RUNTIME_ERROR("unsupported data type.");
Expand Down

0 comments on commit 0e4a3ce

Please sign in to comment.