Skip to content

Commit

Permalink
Workaround to use int32 gather indices [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
hseok-oh committed Sep 6, 2024
1 parent 37b40bd commit 84cef0e
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions compiler/circle-quantizer/src/QuantizeWeightsLLM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ void QuantizeWeightsLLM::visit(luci::CircleFullyConnected *node)

void QuantizeWeightsLLM::visit(luci::CircleGather *node)
{
if (dynamic_cast<luci::CircleConst *>(node->params()) == nullptr)
return;

if (dynamic_cast<luci::CircleConst *>(node->indices()) != nullptr)
return;

auto input = loco::must_cast<luci::CircleConst *>(node->arg(0));
if (elementsize(input) < _skip_length)
return;
Expand All @@ -139,6 +145,11 @@ void QuantizeWeightsLLM::visit(luci::CircleGather *node)
auto new_weights =
_quant_type == Type::Q4_0 ? quantize_q4_block(input) : quantize_q8_block(input);
node->params(new_weights);

// Workaround: indices to INT32 type
auto indices = loco::must_cast<luci::CircleNode *>(node->indices());
if (indices->dtype() == loco::DataType::S64)
indices->dtype(loco::DataType::S32);
}
}

Expand Down

0 comments on commit 84cef0e

Please sign in to comment.