Skip to content

Commit

Permalink
fix(onnx): 支持不同 opset version 的 unsqueeze
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 22, 2023
1 parent 0f587c1 commit 42fb008
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/04kernel/src/kernels/conv/cudnn_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ namespace refactor::kernel {
ASSERT(b->get().shape[0] == y.shape[1], "");
std::vector<dim_t> input(y.rank(), 1);
input[1] = y.shape[1];
*biasExpand = ExpandInfo(
biasExpand.emplace(ExpandInfo(
b->get().dataType,
slice(input.data(), input.size()),
slice(y.shape.data(), y.rank()));
slice(y.shape.data(), y.rank())));
}

// group is not supported
Expand Down
41 changes: 28 additions & 13 deletions src/07onnx/src/operators/unsqueeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
namespace refactor::onnx {
using Op = Unsqueeze;

Op::Unsqueeze() : Operator() {}
Op::Unsqueeze(decltype(axes) axes_) : Operator(), axes(std::move(axes_)) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
ASSERT(attributes.empty(), "Unsqueeze operator should not have attributes");
return OpBox(std::make_unique<Op>());
auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox {
auto iter = ctx.find("opset_version");
auto opsetVer = iter != ctx.end() ? iter->second.int_() : StandardOpsetVersion;

if (opsetVer >= 13) {
ASSERT(attributes.empty(), "Unsqueeze operator should not have attributes");
return OpBox(std::make_unique<Op>(std::nullopt));
} else {
return OpBox(std::make_unique<Op>(std::make_optional(attributes.at("axes").ints())));
}
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
Expand All @@ -21,19 +28,27 @@ namespace refactor::onnx {
auto Op::valueDependentInputs() const -> InputVec { return {1}; }

auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
EXPECT_SIZE(2)
if (inputs.empty()) {
return Err(InferError(ERROR_MSG("Input size error")));
}

auto const &data = inputs[0];
auto const &axes = inputs[1];

if (axes.dataType != DataType::I64 || axes.shape.size() != 1 || !axes.data) {
return Err(InferError(ERROR_MSG("Axes not support")));
slice_t<int64_t> axes_;
if (axes) {
axes_ = slice(axes->data(), axes->size());
} else {
EXPECT_SIZE(2)
auto const &axes__ = inputs[1];
if (axes__.dataType != DataType::I64 || axes__.shape.size() != 1 || !axes__.data) {
return Err(InferError(ERROR_MSG("Axes not support")));
}
EXPECT_VAL(axes__.shape[0], axesSize)
axes_ = slice(axes__.data->get<int64_t>(), axesSize);
}
auto axes_ = axes.data->get<int64_t>();
EXPECT_VAL(axes.shape[0], axesSize)
auto rank = data.rank() + axesSize;

auto rank = data.rank() + axes_.size();
Shape output(rank, DimExpr(-1));
for (auto axis : slice(axes_, axesSize)) {
for (auto axis : axes_) {
if (axis < 0) {
axis += rank;
}
Expand Down
4 changes: 3 additions & 1 deletion src/07onnx/src/operators/unsqueeze.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#define ONNX_UNSQUEEZE_HH

#include "frontend/operator.h"
#include <optional>

namespace refactor::onnx {
using namespace frontend;

struct Unsqueeze final : public Operator {
std::optional<Ints> axes;

Unsqueeze();
explicit Unsqueeze(decltype(axes));

static OpBox build(ModelContext const &, std::string_view, Attributes);
static size_t typeId();
Expand Down
2 changes: 1 addition & 1 deletion src/07onnx/test/test_unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ TEST(infer, Unsqueeze) {
std::copy(val, val + 2, reinterpret_cast<int64_t *>(axes->malloc()));
}
count_t inputs[]{0, 1};
auto infered = Unsqueeze().infer(TensorRefs(edges, slice(inputs, 2)), {true});
auto infered = Unsqueeze(std::nullopt).infer(TensorRefs(edges, slice(inputs, 2)), {true});
ASSERT_TRUE(infered.isOk());
auto outputs = std::move(infered.unwrap());
ASSERT_EQ(outputs.size(), 1);
Expand Down

0 comments on commit 42fb008

Please sign in to comment.