|
16 | 16 |
|
17 | 17 | #include "CircleShapeInferenceHelper.h"
|
18 | 18 |
|
| 19 | +#include "Check.h" |
| 20 | + |
19 | 21 | #include <oops/InternalExn.h>
|
20 | 22 |
|
| 23 | +#include <limits> |
| 24 | + |
21 | 25 | using namespace luci::sinf;
|
22 | 26 |
|
23 | 27 | namespace
|
@@ -157,5 +161,57 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
|
157 | 161 | return output_shape;
|
158 | 162 | }
|
159 | 163 |
|
| 164 | +loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings) |
| 165 | +{ |
| 166 | + const loco::DataType S32 = loco::DataType::S32; |
| 167 | + const loco::DataType S64 = loco::DataType::S64; |
| 168 | + |
| 169 | + // TODO support other data type |
| 170 | + LUCI_ASSERT(paddings->dtype() == S32 || paddings->dtype() == S64, "Support int 32/64 for now"); |
| 171 | + LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2"); |
| 172 | + |
| 173 | + int32_t n = paddings->dim(0).value(); |
| 174 | + int32_t v = paddings->dim(1).value(); |
| 175 | + |
| 176 | + LUCI_ASSERT(v == 2, "paddings should be [n, 2]"); |
| 177 | + LUCI_ASSERT(n == int32_t(input_shape.rank()), |
| 178 | + "paddings [n, 2] should have same value of input rank"); |
| 179 | + |
| 180 | + loco::TensorShape output_shape; |
| 181 | + |
| 182 | + output_shape.rank(input_shape.rank()); |
| 183 | + for (int32_t ni = 0; ni < n; ++ni) |
| 184 | + { |
| 185 | + if (not input_shape.dim(ni).known()) |
| 186 | + { |
| 187 | + output_shape.dim(ni).unset(); |
| 188 | + continue; |
| 189 | + } |
| 190 | + int32_t idx = ni * 2; |
| 191 | + int value = input_shape.dim(ni).value(); |
| 192 | + if (paddings->dtype() == S32) |
| 193 | + { |
| 194 | + value += paddings->at<S32>(idx + 0); // left |
| 195 | + value += paddings->at<S32>(idx + 1); // right |
| 196 | + } |
| 197 | + else |
| 198 | + { |
| 199 | + auto pl = paddings->at<S64>(idx + 0); |
| 200 | + auto pr = paddings->at<S64>(idx + 1); |
| 201 | + auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); |
| 202 | + auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest()); |
| 203 | + LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit"); |
| 204 | + LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit"); |
| 205 | + LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit"); |
| 206 | + LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit"); |
| 207 | + value += static_cast<int32_t>(pl); // left |
| 208 | + value += static_cast<int32_t>(pr); // right |
| 209 | + } |
| 210 | + output_shape.dim(ni) = value; |
| 211 | + } |
| 212 | + |
| 213 | + return output_shape; |
| 214 | +} |
| 215 | + |
160 | 216 | } // namespace sinf
|
161 | 217 | } // namespace luci
|
0 commit comments