Skip to content

Commit 7344b2e

Browse files
authored
[luci] Migrate helperPads to ShapeInferenceHelper (Samsung#13862)
This PR migrates heplerPads.h to CircleShapeInferenceHelper.h and CircleShapeInferenceHelper.cpp. ONE-DCO-1.0-Signed-off-by: JuYoung Lee rsb98759@gmail.com
1 parent 9d1354e commit 7344b2e

File tree

4 files changed

+63
-90
lines changed

4 files changed

+63
-90
lines changed

compiler/luci/service/src/CircleShapeInferenceHelper.cpp

+56
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616

1717
#include "CircleShapeInferenceHelper.h"
1818

19+
#include "Check.h"
20+
1921
#include <oops/InternalExn.h>
2022

23+
#include <limits>
24+
2125
using namespace luci::sinf;
2226

2327
namespace
@@ -157,5 +161,57 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
157161
return output_shape;
158162
}
159163

164+
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings)
165+
{
166+
const loco::DataType S32 = loco::DataType::S32;
167+
const loco::DataType S64 = loco::DataType::S64;
168+
169+
// TODO support other data type
170+
LUCI_ASSERT(paddings->dtype() == S32 || paddings->dtype() == S64, "Support int 32/64 for now");
171+
LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2");
172+
173+
int32_t n = paddings->dim(0).value();
174+
int32_t v = paddings->dim(1).value();
175+
176+
LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
177+
LUCI_ASSERT(n == int32_t(input_shape.rank()),
178+
"paddings [n, 2] should have same value of input rank");
179+
180+
loco::TensorShape output_shape;
181+
182+
output_shape.rank(input_shape.rank());
183+
for (int32_t ni = 0; ni < n; ++ni)
184+
{
185+
if (not input_shape.dim(ni).known())
186+
{
187+
output_shape.dim(ni).unset();
188+
continue;
189+
}
190+
int32_t idx = ni * 2;
191+
int value = input_shape.dim(ni).value();
192+
if (paddings->dtype() == S32)
193+
{
194+
value += paddings->at<S32>(idx + 0); // left
195+
value += paddings->at<S32>(idx + 1); // right
196+
}
197+
else
198+
{
199+
auto pl = paddings->at<S64>(idx + 0);
200+
auto pr = paddings->at<S64>(idx + 1);
201+
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
202+
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
203+
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
204+
LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit");
205+
LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit");
206+
LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit");
207+
value += static_cast<int32_t>(pl); // left
208+
value += static_cast<int32_t>(pr); // right
209+
}
210+
output_shape.dim(ni) = value;
211+
}
212+
213+
return output_shape;
214+
}
215+
160216
} // namespace sinf
161217
} // namespace luci

compiler/luci/service/src/CircleShapeInferenceHelper.h

+4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ loco::TensorShape circle_shape(const luci::CircleNode *node);
4848
// Throw an exception if x and y are not broadcastable.
4949
loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);
5050

51+
// Return shape of pad ops using paddings.
52+
loco::TensorShape pad_shape(const loco::TensorShape &input_shape,
53+
const luci::CircleConst *paddings);
54+
5155
/**
5256
* @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
5357
*

compiler/luci/service/src/HelperPads.h

-88
This file was deleted.

compiler/luci/service/src/Nodes/CirclePad.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
#include "CircleCloneNode.h"
2020
#include "CircleShapeInferenceHelper.h"
21-
#include "HelperPads.h"
2221

2322
namespace luci
2423
{
@@ -35,7 +34,9 @@ loco::TensorShape Algorithm::visit(const luci::CirclePad *node)
3534
{
3635
// TODO support non-const case
3736
auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
38-
return use_paddings(node, paddings);
37+
auto circle_input = loco::must_cast<const luci::CircleNode *>(node->input());
38+
auto input_shape = circle_shape(circle_input);
39+
return pad_shape(input_shape, paddings);
3940
}
4041

4142
} // namespace sinf

0 commit comments

Comments
 (0)