Skip to content

Commit

Permalink
Change min from llvmcall to intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaraldi authored and giordano committed Jan 13, 2025
1 parent 5255bcc commit 4e18a86
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 8 deletions.
4 changes: 4 additions & 0 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ add_tfunc(add_float, 2, 2, math_tfunc, 2)
add_tfunc(sub_float, 2, 2, math_tfunc, 2)
add_tfunc(mul_float, 2, 2, math_tfunc, 8)
add_tfunc(div_float, 2, 2, math_tfunc, 10)
add_tfunc(min_float, 2, 2, math_tfunc, 1)
add_tfunc(max_float, 2, 2, math_tfunc, 1)
add_tfunc(fma_float, 3, 3, math_tfunc, 8)
add_tfunc(muladd_float, 3, 3, math_tfunc, 8)

Expand All @@ -198,6 +200,8 @@ add_tfunc(add_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 8)
add_tfunc(div_float_fast, 2, 2, math_tfunc, 10)
add_tfunc(min_float_fast, 2, 2, math_tfunc, 1)
add_tfunc(max_float_fast, 2, 2, math_tfunc, 1)

# bitwise operators
# -----------------
Expand Down
10 changes: 4 additions & 6 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ module FastMath
export @fastmath

import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, min_float_fast, max_float_fast,
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
import Base: afoldl

Expand Down Expand Up @@ -168,6 +168,9 @@ add_fast(x::T, y::T) where {T<:FloatTypes} = add_float_fast(x, y)
sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y)
mul_fast(x::T, y::T) where {T<:FloatTypes} = mul_float_fast(x, y)
div_fast(x::T, y::T) where {T<:FloatTypes} = div_float_fast(x, y)
# Fast min/max forward straight to the min_float_fast/max_float_fast intrinsics,
# which codegen lowers to llvm.minimum/llvm.maximum calls with fast-math flags
# set (see src/intrinsics.cpp). NOTE(review): with fast-math flags the
# NaN/signed-zero behavior is presumably unspecified — confirm callers don't
# rely on the old ifelse-based ordering of -0.0 vs 0.0.
max_fast(x::T, y::T) where {T<:FloatTypes} = max_float_fast(x, y)
min_fast(x::T, y::T) where {T<:FloatTypes} = min_float_fast(x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = (min_fast(x, y), max_fast(x, y))

@fastmath begin
cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x<y, -1, +1))
Expand Down Expand Up @@ -236,11 +239,6 @@ ComplexTypes = Union{ComplexF32, ComplexF64}

ne_fast(x::T, y::T) where {T<:ComplexTypes} = !(x==y)

# Note: we use the same comparison for min, max, and minmax, so
# that the compiler can convert between them
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))
end

# fall-back implementations and type promotion
Expand Down
34 changes: 33 additions & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ const auto &float_func() {
float_func[sub_float] = true;
float_func[mul_float] = true;
float_func[div_float] = true;
float_func[min_float] = true;
float_func[max_float] = true;
float_func[add_float_fast] = true;
float_func[sub_float_fast] = true;
float_func[mul_float_fast] = true;
float_func[div_float_fast] = true;
float_func[min_float_fast] = true;
float_func[max_float_fast] = true;
float_func[fma_float] = true;
float_func[muladd_float] = true;
float_func[eq_float] = true;
Expand Down Expand Up @@ -134,7 +138,7 @@ uint32_t jl_get_LLVM_VERSION_impl(void)
the bitcast function does nothing except change the type tag
of a value. At the user-level, it is perhaps better known as reinterpret.
boxing is delayed until absolutely necessary, and handled at the point
where the box is needed.
where the box is needed.
all intrinsics have a non-compiled implementation, this file contains
the optimizations for handling them unboxed
*/
Expand Down Expand Up @@ -1490,6 +1494,34 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, ArrayRef<Va
case sub_float: return math_builder(ctx)().CreateFSub(x, y);
case mul_float: return math_builder(ctx)().CreateFMul(x, y);
case div_float: return math_builder(ctx)().CreateFDiv(x, y);
case min_float: {
assert(x->getType() == y->getType());
FunctionCallee minintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::minimum, ArrayRef<Type*>(t));
return ctx.builder.CreateCall(minintr, {x, y});
}
case max_float: {
assert(x->getType() == y->getType());
FunctionCallee maxintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::maximum, ArrayRef<Type*>(t));
return ctx.builder.CreateCall(maxintr, {x, y});
}
case min_float_fast: {
assert(x->getType() == y->getType());
FunctionCallee minintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::minimum, ArrayRef<Type*>(t));
auto call = ctx.builder.CreateCall(minintr, {x, y});
auto fmf = call->getFastMathFlags();
fmf.setFast();
call->copyFastMathFlags(fmf);
return call;
}
case max_float_fast: {
assert(x->getType() == y->getType());
FunctionCallee maxintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::maximum, ArrayRef<Type*>(t));
auto call = ctx.builder.CreateCall(maxintr, {x, y});
auto fmf = call->getFastMathFlags();
fmf.setFast();
call->copyFastMathFlags(fmf);
return call;
}
case add_float_fast: return math_builder(ctx, true)().CreateFAdd(x, y);
case sub_float_fast: return math_builder(ctx, true)().CreateFSub(x, y);
case mul_float_fast: return math_builder(ctx, true)().CreateFMul(x, y);
Expand Down
4 changes: 4 additions & 0 deletions src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ADD_I(sub_float, 2) \
ADD_I(mul_float, 2) \
ADD_I(div_float, 2) \
ADD_I(min_float, 2) \
ADD_I(max_float, 2) \
ADD_I(fma_float, 3) \
ADD_I(muladd_float, 3) \
/* fast arithmetic */ \
Expand All @@ -25,6 +27,8 @@
ALIAS(sub_float_fast, sub_float) \
ALIAS(mul_float_fast, mul_float) \
ALIAS(div_float_fast, div_float) \
ALIAS(min_float_fast, min_float) \
ALIAS(max_float_fast, max_float) \
/* same-type comparisons */ \
ADD_I(eq_int, 2) \
ADD_I(ne_int, 2) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,8 @@ JL_DLLEXPORT jl_value_t *jl_add_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_sub_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_mul_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_div_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_min_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_max_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_fma_float(jl_value_t *a, jl_value_t *b, jl_value_t *c);
JL_DLLEXPORT jl_value_t *jl_muladd_float(jl_value_t *a, jl_value_t *b, jl_value_t *c);

Expand Down
39 changes: 38 additions & 1 deletion src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1398,13 +1398,50 @@ bi_iintrinsic_fast(LLVMURem, rem, urem_int, u)
bi_iintrinsic_fast(jl_LLVMSMod, smod, smod_int, )
#define frem(a, b) \
fp_select2(a, b, fmod)

un_fintrinsic(neg_float,neg_float)
bi_fintrinsic(add,add_float)
bi_fintrinsic(sub,sub_float)
bi_fintrinsic(mul,mul_float)
bi_fintrinsic(div,div_float)

// Runtime fallback for the min_float intrinsic on Float32, mirroring the
// semantics of llvm.minimum: a NaN in either operand propagates, and
// -0.0 is treated as smaller than +0.0.
float min_float(float x, float y)
{
    // x - y is NaN whenever either operand is NaN, so it doubles as the
    // propagated NaN payload below.
    float delta = x - y;
    if (isnan(x) || isnan(y))
        return delta;
    // signbit(delta) is set when x < y, and also for min(-0.0f, 0.0f)
    // where the subtraction yields -0.0f — selecting x in both cases.
    return signbit(delta) ? x : y;
}

// Runtime fallback for the min_float intrinsic on Float64; same contract as
// min_float above: NaN propagates, and -0.0 orders below +0.0.
double min_double(double x, double y)
{
    double delta = x - y;
    // Either operand being NaN makes delta NaN; return it to propagate.
    if (isnan(x) || isnan(y))
        return delta;
    // x < y (or the -0.0 vs +0.0 tie) sets the sign bit of delta,
    // in which case x is the minimum.
    return signbit(delta) ? x : y;
}

#define _min(a, b) sizeof(a) == sizeof(float) ? min_float(a, b) : min_double(a, b)
bi_fintrinsic(_min, min_float)

// Runtime fallback for the max_float intrinsic on Float32, mirroring the
// semantics of llvm.maximum: a NaN in either operand propagates, and
// +0.0 is treated as larger than -0.0.
float max_float(float x, float y)
{
    // x - y is NaN whenever either operand is, providing the NaN to return.
    float delta = x - y;
    if (isnan(x) || isnan(y))
        return delta;
    // The sign bit of delta means x < y (or the max(-0.0f, 0.0f) tie,
    // where the subtraction yields -0.0f); in both cases y is the maximum.
    return signbit(delta) ? y : x;
}

// Runtime fallback for the max_float intrinsic on Float64; same contract as
// max_float above: NaN propagates, and +0.0 orders above -0.0.
double max_double(double x, double y)
{
    double diff = x - y;
    // BUG FIX: the select was inverted (`signbit(diff) ? x : y`), which picks
    // the operand that max_float picks for *min* — i.e. max_double returned
    // the minimum. signbit(diff) set means x < y, so the maximum is y.
    // (Also renamed the local from the misleading `argmin`.)
    double argmax = signbit(diff) ? y : x;
    // diff is NaN whenever either operand is NaN; propagate it.
    int is_nan = isnan(x) || isnan(y);
    return is_nan ? diff : argmax;
}

#define _max(a, b) sizeof(a) == sizeof(float) ? max_float(a, b) : max_double(a, b)
bi_fintrinsic(_max, max_float)

// ternary operators //
// runtime fma is broken on windows, define julia_fma(f) ourself with fma_emulated as reference.
#if defined(_OS_WINDOWS_)
Expand Down

0 comments on commit 4e18a86

Please sign in to comment.