Skip to content

Commit

Permalink
Fix overflow behavior for spark decimal sum aggregate fucntion
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Sep 30, 2024
1 parent e51c1cc commit d65a468
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 22 deletions.
7 changes: 2 additions & 5 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace facebook::velox::functions::aggregate::sparksql {
/// @tparam TInputType The raw input data type.
/// @tparam TSumType The type of sum in the output of partial aggregation or the
/// final output type of final aggregation.
template <typename TInputType, typename TSumType>
template <typename TInputType, typename TSumType, uint8_t ResultPrecision>
class DecimalSumAggregate {
public:
using InputType = Row<TInputType>;
Expand Down Expand Up @@ -78,11 +78,8 @@ class DecimalSumAggregate {
}
auto const adjustedSum =
DecimalUtil::adjustSumForOverflow(sum.value(), overflow);
constexpr uint8_t maxPrecision = std::is_same_v<TSumType, int128_t>
? LongDecimalType::kMaxPrecision
: ShortDecimalType::kMaxPrecision;
if (adjustedSum.has_value() &&
DecimalUtil::valueInPrecisionRange(adjustedSum, maxPrecision)) {
DecimalUtil::valueInPrecisionRange(adjustedSum, ResultPrecision)) {
return adjustedSum;
} else {
// Found overflow during computing adjusted sum.
Expand Down
81 changes: 64 additions & 17 deletions velox/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,64 @@ void checkAccumulatorRowType(const TypePtr& type) {
type->childAt(0)->isShortDecimal() || type->childAt(0)->isLongDecimal());
VELOX_CHECK_EQ(type->childAt(1)->kind(), TypeKind::BOOLEAN);
}

std::unique_ptr<exec::Aggregate> constructDecimalSumAgg(
const TypePtr& inputType,
const TypePtr& sumType,
const TypePtr& resultType) {
VELOX_CHECK(sumType->isDecimal());
uint8_t precision = 0;
if (sumType->isShortDecimal()) {
precision = sumType->asShortDecimal().precision();
} else {
precision = sumType->asLongDecimal().precision();
}
switch (precision) {
#define PRECISION_CASE(precision) \
case precision: \
if (inputType->isShortDecimal() && sumType->isShortDecimal()) { \
return std::make_unique<exec::SimpleAggregateAdapter< \
DecimalSumAggregate<int64_t, int64_t, precision>>>(resultType); \
} else if (inputType->isShortDecimal() && sumType->isLongDecimal()) { \
return std::make_unique<exec::SimpleAggregateAdapter< \
DecimalSumAggregate<int64_t, int128_t, precision>>>(resultType); \
} else { \
return std::make_unique<exec::SimpleAggregateAdapter< \
DecimalSumAggregate<int128_t, int128_t, precision>>>(resultType); \
}
PRECISION_CASE(11)
PRECISION_CASE(12)
PRECISION_CASE(13)
PRECISION_CASE(14)
PRECISION_CASE(15)
PRECISION_CASE(16)
PRECISION_CASE(17)
PRECISION_CASE(18)
PRECISION_CASE(19)
PRECISION_CASE(20)
PRECISION_CASE(21)
PRECISION_CASE(22)
PRECISION_CASE(23)
PRECISION_CASE(24)
PRECISION_CASE(25)
PRECISION_CASE(26)
PRECISION_CASE(27)
PRECISION_CASE(28)
PRECISION_CASE(29)
PRECISION_CASE(30)
PRECISION_CASE(31)
PRECISION_CASE(32)
PRECISION_CASE(33)
PRECISION_CASE(34)
PRECISION_CASE(35)
PRECISION_CASE(36)
PRECISION_CASE(37)
PRECISION_CASE(38)
#undef PRECISION_CASE
default:
VELOX_UNREACHABLE();
}
}
} // namespace

exec::AggregateRegistrationResult registerSum(
Expand Down Expand Up @@ -100,14 +158,8 @@ exec::AggregateRegistrationResult registerSum(
BIGINT());
case TypeKind::BIGINT: {
if (inputType->isShortDecimal()) {
auto const sumType = getDecimalSumType(resultType);
if (sumType->isShortDecimal()) {
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int64_t, int64_t>>>(resultType);
} else if (sumType->isLongDecimal()) {
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int64_t, int128_t>>>(resultType);
}
return constructDecimalSumAgg(
inputType, getDecimalSumType(resultType), resultType);
}
return std::make_unique<SumAggregate<int64_t, int64_t, int64_t>>(
BIGINT());
Expand All @@ -116,8 +168,8 @@ exec::AggregateRegistrationResult registerSum(
VELOX_CHECK(inputType->isLongDecimal());
// If inputType is long decimal,
// its output type is always long decimal.
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int128_t, int128_t>>>(resultType);
return constructDecimalSumAgg(
inputType, getDecimalSumType(resultType), resultType);
}
case TypeKind::REAL:
if (resultType->kind() == TypeKind::REAL) {
Expand All @@ -138,13 +190,8 @@ exec::AggregateRegistrationResult registerSum(
checkAccumulatorRowType(inputType);
// For the intermediate aggregation step, input intermediate sum
// type is equal to final result sum type.
if (inputType->childAt(0)->isShortDecimal()) {
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int64_t, int64_t>>>(resultType);
} else if (inputType->childAt(0)->isLongDecimal()) {
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int128_t, int128_t>>>(resultType);
}
return constructDecimalSumAgg(
inputType->childAt(0), inputType->childAt(0), resultType);
}
[[fallthrough]];
default:
Expand Down
24 changes: 24 additions & 0 deletions velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,30 @@ TEST_F(SumAggregationTest, decimalGroupBySumOverflow) {
decimalGroupBySumOverflow(decimalVector);
}

TEST_F(SumAggregationTest, decimalLargeCountRowsOverflow) {
// When the precision of the input type is less than 28, the precision of the
// result type
// will be less than 38. Therefore, we need to check if the result overflows
// due to the result type's precision limitations. This overflow is more
// likely to occur when dealing with a large number of rows. To simulate this
// case in unit testing, we create instances with input values that are very
// close to the overflow threshold, but only for the final step. For example,
// if the input type is Decimal(3, 2) and all values are 1.00, with
// 100,000,000,001 rows and 2 partitions, the final step's input would be
// (50,000,000,000.00, false), (50,000,000,000.00, false).
auto accumulators = makeRowVector({makeRowVector(
{makeFlatVector<int64_t>(
{5000'000'000'100L, 5000'000'000'000L}, DECIMAL(13, 2)),
makeFlatVector<bool>({false, false})})});
auto node = PlanBuilder(pool())
.values({accumulators})
.finalAggregation({}, {"spark_sum(c0)"}, {{DECIMAL(3, 2)}})
.planNode();
auto expected = makeRowVector(
{makeNullableFlatVector<int128_t>({std::nullopt}, DECIMAL(13, 2))});
AssertQueryBuilder(node).assertResults(expected);
}

TEST_F(SumAggregationTest, decimalAllNullValues) {
std::vector<std::optional<int128_t>> allNull(5, std::nullopt);
auto input = makeRowVector(
Expand Down

0 comments on commit d65a468

Please sign in to comment.