From 42fb008b71e360e94305034d58c1776f9bc35f16 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 22 Nov 2023 17:50:22 +0800 Subject: [PATCH] =?UTF-8?q?fix(onnx):=20=E6=94=AF=E6=8C=81=E4=B8=8D?= =?UTF-8?q?=E5=90=8C=20opset=20version=20=E7=9A=84=20unsqueeze?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/04kernel/src/kernels/conv/cudnn_kernel.cc | 4 +- src/07onnx/src/operators/unsqueeze.cc | 41 +++++++++++++------ src/07onnx/src/operators/unsqueeze.hh | 4 +- src/07onnx/test/test_unsqueeze.cpp | 2 +- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/04kernel/src/kernels/conv/cudnn_kernel.cc b/src/04kernel/src/kernels/conv/cudnn_kernel.cc index e6d368ad..12bf09c8 100644 --- a/src/04kernel/src/kernels/conv/cudnn_kernel.cc +++ b/src/04kernel/src/kernels/conv/cudnn_kernel.cc @@ -22,10 +22,10 @@ namespace refactor::kernel { ASSERT(b->get().shape[0] == y.shape[1], ""); std::vector input(y.rank(), 1); input[1] = y.shape[1]; - *biasExpand = ExpandInfo( + biasExpand.emplace(ExpandInfo( b->get().dataType, slice(input.data(), input.size()), - slice(y.shape.data(), y.rank())); + slice(y.shape.data(), y.rank()))); } // group is not supported diff --git a/src/07onnx/src/operators/unsqueeze.cc b/src/07onnx/src/operators/unsqueeze.cc index 8ea932d6..dfd6dc6e 100644 --- a/src/07onnx/src/operators/unsqueeze.cc +++ b/src/07onnx/src/operators/unsqueeze.cc @@ -5,11 +5,18 @@ namespace refactor::onnx { using Op = Unsqueeze; - Op::Unsqueeze() : Operator() {} + Op::Unsqueeze(decltype(axes) axes_) : Operator(), axes(std::move(axes_)) {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Unsqueeze operator should not have attributes"); - return OpBox(std::make_unique()); + auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox { + auto iter = ctx.find("opset_version"); + auto opsetVer = iter != ctx.end() ? iter->second.int_() : StandardOpsetVersion; + + if (opsetVer >= 13) { + ASSERT(attributes.empty(), "Unsqueeze operator should not have attributes"); + return OpBox(std::make_unique(std::nullopt)); + } else { + return OpBox(std::make_unique(std::make_optional(attributes.at("axes").ints()))); + } } auto Op::typeId() -> size_t { static uint8_t ID = 1; @@ -21,19 +28,27 @@ namespace refactor::onnx { auto Op::valueDependentInputs() const -> InputVec { return {1}; } auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult { - EXPECT_SIZE(2) + if (inputs.empty()) { + return Err(InferError(ERROR_MSG("Input size error"))); + } auto const &data = inputs[0]; - auto const &axes = inputs[1]; - - if (axes.dataType != DataType::I64 || axes.shape.size() != 1 || !axes.data) { - return Err(InferError(ERROR_MSG("Axes not support"))); + slice_t axes_; + if (axes) { + axes_ = slice(axes->data(), axes->size()); + } else { + EXPECT_SIZE(2) + auto const &axes__ = inputs[1]; + if (axes__.dataType != DataType::I64 || axes__.shape.size() != 1 || !axes__.data) { + return Err(InferError(ERROR_MSG("Axes not support"))); + } + EXPECT_VAL(axes__.shape[0], axesSize) + axes_ = slice(axes__.data->get(), axesSize); } - auto axes_ = axes.data->get(); - EXPECT_VAL(axes.shape[0], axesSize) - auto rank = data.rank() + axesSize; + + auto rank = data.rank() + axes_.size(); Shape output(rank, DimExpr(-1)); - for (auto axis : slice(axes_, axesSize)) { + for (auto axis : axes_) { if (axis < 0) { axis += rank; } diff --git a/src/07onnx/src/operators/unsqueeze.hh b/src/07onnx/src/operators/unsqueeze.hh index 02d4e99e..0ef8eede 100644 --- a/src/07onnx/src/operators/unsqueeze.hh +++ b/src/07onnx/src/operators/unsqueeze.hh @@ -2,13 +2,15 @@ #define ONNX_UNSQUEEZE_HH #include "frontend/operator.h" +#include namespace refactor::onnx { using namespace frontend; struct Unsqueeze final : public Operator { + std::optional axes; - Unsqueeze(); + explicit Unsqueeze(decltype(axes)); static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/test/test_unsqueeze.cpp b/src/07onnx/test/test_unsqueeze.cpp index 4b7ce930..66e3e83d 100644 --- a/src/07onnx/test/test_unsqueeze.cpp +++ b/src/07onnx/test/test_unsqueeze.cpp @@ -18,7 +18,7 @@ TEST(infer, Unsqueeze) { std::copy(val, val + 2, reinterpret_cast(axes->malloc())); } count_t inputs[]{0, 1}; - auto infered = Unsqueeze().infer(TensorRefs(edges, slice(inputs, 2)), {true}); + auto infered = Unsqueeze(std::nullopt).infer(TensorRefs(edges, slice(inputs, 2)), {true}); ASSERT_TRUE(infered.isOk()); auto outputs = std::move(infered.unwrap()); ASSERT_EQ(outputs.size(), 1);