Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support tidb truncate function #9842

Merged
merged 10 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ const std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
{tipb::ScalarFuncSig::Radians, "radians"},
{tipb::ScalarFuncSig::Sin, "sin"},
{tipb::ScalarFuncSig::Tan, "tan"},
{tipb::ScalarFuncSig::TruncateInt, "trunc"},
{tipb::ScalarFuncSig::TruncateReal, "trunc"},
//{tipb::ScalarFuncSig::TruncateDecimal, "cast"},
{tipb::ScalarFuncSig::TruncateUint, "trunc"},
{tipb::ScalarFuncSig::TruncateInt, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateReal, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateDecimal, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateUint, "tidbTruncateWithFrac"},

{tipb::ScalarFuncSig::LogicalAnd, "and"},
{tipb::ScalarFuncSig::LogicalOr, "or"},
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Functions/FunctionsRound.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void registerFunctionsRound(FunctionFactory & factory)
factory.registerFunction<FunctionTrunc>("truncate", FunctionFactory::CaseInsensitive);

factory.registerFunction<FunctionTiDBRoundWithFrac>();
factory.registerFunction<FunctionTiDBTruncateWithFrac>();
}

} // namespace DB
183 changes: 146 additions & 37 deletions dbms/src/Functions/FunctionsRound.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ struct ConstPowOf10
static_assert(!overflow, "Computation overflows");
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBFloatingRound
{
static_assert(std::is_floating_point_v<InputType>);
Expand All @@ -974,8 +974,8 @@ struct TiDBFloatingRound

static OutputType eval(InputType input, FracType frac)
{
// modified from <https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L33-L48>.

// modified from https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L33-L48 and
// https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L50-L61.
Copy link
Contributor

@gengliqi gengliqi Feb 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that ConstPowOf10 can be used to optimize the code in Line 986 base = std::pow(10.0, frac)

Copy link
Contributor Author

@guo-shaoge guo-shaoge Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs more efforts to make sure correctness and performance improvement. Maybe we should to it in another pr

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a new issue to record: #9860

auto value = static_cast<OutputType>(input);
auto base = 1.0;

Expand All @@ -997,9 +997,16 @@ struct TiDBFloatingRound
value = scaled_value;
}

// floating-point environment is thread-local, so `fesetround` is thread-safe.
std::fesetround(FE_TONEAREST);
value = std::nearbyint(value);
if constexpr (is_tidb_truncate)
{
value = std::trunc(value);
}
else
{
// floating-point environment is thread-local, so `fesetround` is thread-safe.
std::fesetround(FE_TONEAREST);
value = std::nearbyint(value);
}

if (frac != 0)
{
Expand All @@ -1018,7 +1025,7 @@ struct TiDBFloatingRound
}
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBIntegerRound
{
static_assert(is_integer_v<InputType>);
Expand Down Expand Up @@ -1061,7 +1068,7 @@ struct TiDBIntegerRound
}
}

static OutputType eval(InputType input, FracType frac)
static OutputType evalRound(InputType input, FracType frac)
{
auto value = static_cast<OutputType>(input);

Expand Down Expand Up @@ -1107,6 +1114,33 @@ struct TiDBIntegerRound
return castBack((input < 0), absolute_value);
}
}

static OutputType evalTruncate(InputType input, FracType frac)
{
// modified from https://github.com/pingcap/tidb/blob/807b8923c0181d89d4ea8e4195f9d27d299298a7/pkg/expression/builtin_math.go#L2196-L2219
guo-shaoge marked this conversation as resolved.
Show resolved Hide resolved
const auto value = static_cast<OutputType>(input);
if (frac >= 0)
return value;
else if (frac <= -max_digits)
return 0;
else
{
// To make sure static_cast<OutputType>(Pow::result[-frac]) will not overflow.
assert(Pow::result[-frac] < std::numeric_limits<OutputType>::max());
const auto base = static_cast<OutputType>(Pow::result[-frac]);

const auto remainder = value % base;
return value - remainder;
}
}

static OutputType eval(InputType input, FracType frac)
{
if constexpr (is_tidb_truncate)
return evalTruncate(input, frac);
else
return evalRound(input, frac);
}
};

struct TiDBDecimalRoundInfo
Expand All @@ -1127,7 +1161,7 @@ struct TiDBDecimalRoundInfo
{}
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBDecimalRound
{
static_assert(IsDecimal<InputType>);
Expand Down Expand Up @@ -1158,10 +1192,13 @@ struct TiDBDecimalRound
auto remainder = absolute_value % base;

absolute_value -= remainder;
if (remainder >= base / 2)
if constexpr (!is_tidb_truncate)
{
// round up.
absolute_value += base;
if (remainder >= base / 2)
{
// round up.
absolute_value += base;
}
}
}

Expand Down Expand Up @@ -1198,14 +1235,21 @@ struct TiDBDecimalRound

struct TiDBRoundPrecisionInferer
{
static std::tuple<PrecType, ScaleType> infer(PrecType prec, ScaleType scale, FracType frac, bool is_const_frac)
static std::tuple<PrecType, ScaleType> infer(
PrecType prec,
ScaleType scale,
FracType frac,
bool is_const_frac,
bool is_tidb_truncate)
{
assert(prec >= scale);
PrecType int_prec = prec - scale;
ScaleType new_scale = scale;

// +1 for possible overflow, e.g. round(99999.9) => 100000
ScaleType int_prec_increment = 1;
if (is_tidb_truncate)
int_prec_increment = 0;

if (is_const_frac)
{
Expand All @@ -1219,6 +1263,14 @@ struct TiDBRoundPrecisionInferer
}

PrecType new_prec = std::min(decimal_max_prec, int_prec + int_prec_increment + new_scale);
if (new_prec == 0)
{
// new_prec can be zero when the prec is eq to scale and frac is le to zero for truncate:
// select truncate(0.22, 0) from t_col_decimal_2_2;
// Not possible for round, because int_prec_increment is 1 for round.
RUNTIME_CHECK(is_tidb_truncate && is_const_frac && frac <= 0 && prec == scale);
new_prec = 1;
}
return std::make_tuple(new_prec, new_scale);
}
};
Expand All @@ -1239,7 +1291,8 @@ template <
typename OutputType,
typename InputColumn,
typename FracColumn,
typename OutputColumn>
typename OutputColumn,
bool is_tidb_truncate>
struct TiDBRound
{
static void apply(const TiDBRoundArguments & args)
Expand Down Expand Up @@ -1273,11 +1326,14 @@ struct TiDBRound
auto frac_data = frac_column->template getValue<FracType>();

if constexpr (std::is_floating_point_v<InputType>)
output_data[0] = TiDBFloatingRound<InputType, OutputType>::eval(input_data, frac_data);
output_data[0]
= TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data);
else if constexpr (IsDecimal<InputType>)
output_data[0] = TiDBDecimalRound<InputType, OutputType>::eval(input_data, frac_data, info);
output_data[0]
= TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data, info);
else
output_data[0] = TiDBIntegerRound<InputType, OutputType>::eval(input_data, frac_data);
output_data[0]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data);
}
else
{
Expand All @@ -1287,11 +1343,17 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data, frac_data[i]);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data,
frac_data[i]);
else if constexpr (IsDecimal<InputType>)
output_data[i] = TiDBDecimalRound<InputType, OutputType>::eval(input_data, frac_data[i], info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data,
frac_data[i],
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data, frac_data[i]);
output_data[i]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data[i]);
}
}
}
Expand All @@ -1305,11 +1367,17 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data[i], frac_data);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data);
else if constexpr (IsDecimal<InputType>)
output_data[i] = TiDBDecimalRound<InputType, OutputType>::eval(input_data[i], frac_data, info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data,
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data[i], frac_data);
output_data[i]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data[i], frac_data);
}
}
else
Expand All @@ -1320,27 +1388,34 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data[i], frac_data[i]);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i]);
else if constexpr (IsDecimal<InputType>)
output_data[i]
= TiDBDecimalRound<InputType, OutputType>::eval(input_data[i], frac_data[i], info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i],
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data[i], frac_data[i]);
output_data[i] = TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i]);
}
}
}
}
};

/**
* round(x, d) for TiDB.
* round(x, d) and truncate(x, d) for TiDB.
*/
class FunctionTiDBRoundWithFrac : public IFunction
template <typename Name, bool is_tidb_truncate>
class FunctionTiDBRoundImpl : public IFunction
{
public:
static constexpr auto name = "tidbRoundWithFrac";
static constexpr auto name = Name::name;

static FunctionPtr create(const Context &) { return std::make_shared<FunctionTiDBRoundWithFrac>(); }
static FunctionPtr create(const Context &) { return std::make_shared<FunctionTiDBRoundImpl>(); }

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
Expand All @@ -1355,7 +1430,6 @@ class FunctionTiDBRoundWithFrac : public IFunction
// non-const frac column can generate different return types. Plese see TiDBRoundPrecisionInferer for details.
bool useDefaultImplementationForConstants() const override { return false; }

private:
static FracType getFracFromConstColumn(const ColumnConst * column)
{
using UnsignedFrac = make_unsigned_t<FracType>;
Expand Down Expand Up @@ -1383,6 +1457,7 @@ class FunctionTiDBRoundWithFrac : public IFunction
}
}

private:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
checkArguments(arguments);
Expand Down Expand Up @@ -1430,7 +1505,8 @@ class FunctionTiDBRoundWithFrac : public IFunction
else
is_const_frac = false;

auto [new_prec, new_scale] = TiDBRoundPrecisionInferer::infer(prec, scale, frac, is_const_frac);
auto [new_prec, new_scale]
= TiDBRoundPrecisionInferer::infer(prec, scale, frac, is_const_frac, is_tidb_truncate);
return createDecimal(new_prec, new_scale);
}
}
Expand Down Expand Up @@ -1566,16 +1642,44 @@ class FunctionTiDBRoundWithFrac : public IFunction
if (args.input_column->isColumnConst())
{
if (args.frac_column->isColumnConst())
TiDBRound<InputType, FracType, OutputType, ColumnConst, ColumnConst, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
ColumnConst,
ColumnConst,
OutputColumn,
is_tidb_truncate>::apply(args);
else
TiDBRound<InputType, FracType, OutputType, ColumnConst, FracColumn, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
ColumnConst,
FracColumn,
OutputColumn,
is_tidb_truncate>::apply(args);
}
else
{
if (args.frac_column->isColumnConst())
TiDBRound<InputType, FracType, OutputType, InputColumn, ColumnConst, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
InputColumn,
ColumnConst,
OutputColumn,
is_tidb_truncate>::apply(args);
else
TiDBRound<InputType, FracType, OutputType, InputColumn, FracColumn, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
InputColumn,
FracColumn,
OutputColumn,
is_tidb_truncate>::apply(args);
}

return true;
Expand Down Expand Up @@ -1622,6 +1726,9 @@ struct NameRoundDecimalToInt { static constexpr auto name = "roundDecimalToInt";
struct NameCeilDecimalToInt { static constexpr auto name = "ceilDecimalToInt"; };
struct NameFloorDecimalToInt { static constexpr auto name = "floorDecimalToInt"; };
struct NameTruncDecimalToInt { static constexpr auto name = "truncDecimalToInt"; };

struct NameTiDBRoundWithFrac { static constexpr auto name = "tidbRoundWithFrac"; };
struct NameTiDBTruncateWithFrac { static constexpr auto name = "tidbTruncateWithFrac"; };
// clang-format on

using FunctionRoundToExp2 = FunctionUnaryArithmetic<RoundToExp2Impl, NameRoundToExp2, false>;
Expand All @@ -1638,6 +1745,8 @@ using FunctionCeilDecimalToInt = FunctionRoundingDecimalToInt<NameCeilDecimalToI
using FunctionFloorDecimalToInt = FunctionRoundingDecimalToInt<NameFloorDecimalToInt, RoundingMode::Floor>;
using FunctionTruncDecimalToInt = FunctionRoundingDecimalToInt<NameTruncDecimalToInt, RoundingMode::Trunc>;

using FunctionTiDBRoundWithFrac = FunctionTiDBRoundImpl<NameTiDBRoundWithFrac, /*is_tidb_truncate=*/false>;
using FunctionTiDBTruncateWithFrac = FunctionTiDBRoundImpl<NameTiDBTruncateWithFrac, /*is_tidb_truncate=*/true>;

struct PositiveMonotonicity
{
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5893,9 +5893,9 @@ class FormatImpl : public IFunction
const TiDBDecimalRoundInfo & info [[maybe_unused]])
{
if constexpr (IsDecimal<T>)
return TiDBDecimalRound<T, T>::eval(number, max_num_decimals, info);
return TiDBDecimalRound<T, T, /*is_tidb_truncate=*/false>::eval(number, max_num_decimals, info);
else if constexpr (std::is_floating_point_v<T>)
return TiDBFloatingRound<T, Float64>::eval(number, max_num_decimals);
return TiDBFloatingRound<T, Float64, /*is_tidb_truncate=*/false>::eval(number, max_num_decimals);
else
{
static_assert(std::is_integral_v<T>);
Expand Down
Loading