Skip to content

Commit

Permalink
implement faster floating-point isless (#39090)
Browse files Browse the repository at this point in the history
* implement faster floating-point `isless`

Previously `isless` relied on the C intrinsic `fpislt` in
`src/runtime_intrinsics.c`, while the new implementation in Julia
arguably generates better code, namely:

 1. The NaN-check compiles to a single instruction + branch amenable
    to branch prediction in arguably most use cases (i.e. comparing
    non-NaN floats), thus speeding up execution.
 2. The compiler now often manages to remove NaN-computation if the
    embedding code has already proven the arguments to be non-NaN.
 3. The actual operation compares both arguments as sign-magnitude
    integers instead of case analysis based on the sign of one
    argument. This symmetric treatment may generate vectorized
    instructions for the sign-magnitude conversion depending on how the
    arguments are laid out.

The actual behaviour of `isless` did not change and apart from the
Julia-specific NaN-handling (which may be up for debate) the resulting
total order corresponds to the IEEE-754 specified `totalOrder`.

While the new implementation no longer generates fully branchless code I
did not manage to construct a use case where this was detrimental: the
saved work seems to outweigh the potential cost of a branch
misprediction in all of my tests with various NaN-polluted data. Also
auto-vectorization was not effective on the previous `fpislt` either.

Quick benchmarks (AMD A10-7860K) on `sort`, avoiding the specialized
algorithm:

```julia
a = rand(1000);
@btime sort($a, lt=(a,b)->isless(a,b));
    # before: 56.030 μs (1 allocation: 7.94 KiB)
    #  after: 40.853 μs (1 allocation: 7.94 KiB)
a = rand(1000000);
@btime sort($a, lt=(a,b)->isless(a,b));
    # before: 159.499 ms (2 allocations: 7.63 MiB)
    #  after: 120.536 ms (2 allocations: 7.63 MiB)
a = [rand((rand(), NaN)) for _ in 1:1000000];
@btime sort($a, lt=(a,b)->isless(a,b));
    # before: 111.925 ms (2 allocations: 7.63 MiB)
    #  after:  77.669 ms (2 allocations: 7.63 MiB)
```


* Remove old intrinsic fpislt code

Co-authored-by: Mustafa Mohamad <mus-m@outlook.com>
  • Loading branch information
stev47 and musm authored Apr 25, 2021
1 parent 248c02f commit 79920db
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 50 deletions.
1 change: 0 additions & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ StatementMacros:
- checked_intrinsic_ctype
- cvt_iintrinsic
- fpiseq_n
- fpislt_n
- ter_fintrinsic
- ter_intrinsic_ctype
- un_fintrinsic
Expand Down
1 change: 0 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ add_tfunc(ne_float, 2, 2, cmp_tfunc, 2)
add_tfunc(lt_float, 2, 2, cmp_tfunc, 2)
add_tfunc(le_float, 2, 2, cmp_tfunc, 2)
add_tfunc(fpiseq, 2, 2, cmp_tfunc, 1)
add_tfunc(fpislt, 2, 2, cmp_tfunc, 1)
add_tfunc(eq_float_fast, 2, 2, cmp_tfunc, 1)
add_tfunc(ne_float_fast, 2, 2, cmp_tfunc, 1)
add_tfunc(lt_float_fast, 2, 2, cmp_tfunc, 1)
Expand Down
16 changes: 13 additions & 3 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,19 @@ end
isequal(x::Float16, y::Float16) = fpiseq(x, y)
isequal(x::Float32, y::Float32) = fpiseq(x, y)
isequal(x::Float64, y::Float64) = fpiseq(x, y)
isless( x::Float16, y::Float16) = fpislt(x, y)
isless( x::Float32, y::Float32) = fpislt(x, y)
isless( x::Float64, y::Float64) = fpislt(x, y)

# Reinterpret the bits of a float as a signed integer whose ordinary `<`
# ordering agrees with the sign-magnitude ordering of the float's bit pattern.
#
# `x` must be a type for which `inttype` is defined; the result has type
# `inttype(typeof(x))`.
@inline function _fpint(x)
    IntT = inttype(typeof(x))
    ix = reinterpret(IntT, x)
    # A set sign bit makes `ix` negative. Xor-ing with `typemax(IntT)` flips
    # every bit except the sign bit, so a larger magnitude maps to a smaller
    # (more negative) integer. Non-negative patterns are already ordered.
    return ifelse(ix < zero(IntT), ix ⊻ typemax(IntT), ix)
end

# Total-order comparison for two IEEE floats of the same type.
# If either argument is NaN the result is `!isnan(a)`: a non-NaN `a` sorts
# before a NaN `b`, and a NaN `a` never sorts before anything. Otherwise the
# arguments are compared through their `_fpint` integer images.
@inline function isless(a::T, b::T) where T<:IEEEFloat
    if isnan(a) || isnan(b)
        return !isnan(a)
    end

    lhs = _fpint(a)
    rhs = _fpint(b)
    return lhs < rhs
end

# Exact Float (Tf) vs Integer (Ti) comparisons
# Assumes:
Expand Down
21 changes: 0 additions & 21 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ static void jl_init_intrinsic_functions_codegen(void)
float_func[lt_float_fast] = true;
float_func[le_float_fast] = true;
float_func[fpiseq] = true;
float_func[fpislt] = true;
float_func[abs_float] = true;
float_func[copysign_float] = true;
float_func[ceil_llvm] = true;
Expand Down Expand Up @@ -1165,26 +1164,6 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg
ctx.builder.CreateICmpEQ(xi, yi));
}

case fpislt: {
*newtyp = jl_bool_type;
Type *it = INTT(t);
Value *xi = ctx.builder.CreateBitCast(x, it);
Value *yi = ctx.builder.CreateBitCast(y, it);
return ctx.builder.CreateOr(
ctx.builder.CreateAnd(
ctx.builder.CreateFCmpORD(x, x),
ctx.builder.CreateFCmpUNO(y, y)),
ctx.builder.CreateAnd(
ctx.builder.CreateFCmpORD(x, y),
ctx.builder.CreateOr(
ctx.builder.CreateAnd(
ctx.builder.CreateICmpSGE(xi, ConstantInt::get(it, 0)),
ctx.builder.CreateICmpSLT(xi, yi)),
ctx.builder.CreateAnd(
ctx.builder.CreateICmpSLT(xi, ConstantInt::get(it, 0)),
ctx.builder.CreateICmpUGT(xi, yi)))));
}

case and_int: return ctx.builder.CreateAnd(x, y);
case or_int: return ctx.builder.CreateOr(x, y);
case xor_int: return ctx.builder.CreateXor(x, y);
Expand Down
1 change: 0 additions & 1 deletion src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
ALIAS(lt_float_fast, lt_float) \
ALIAS(le_float_fast, le_float) \
ADD_I(fpiseq, 2) \
ADD_I(fpislt, 2) \
/* bitwise operators */ \
ADD_I(and_int, 2) \
ADD_I(or_int, 2) \
Expand Down
1 change: 0 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,6 @@ JL_DLLEXPORT jl_value_t *jl_ne_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_lt_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_le_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_fpiseq(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_fpislt(jl_value_t *a, jl_value_t *b);

JL_DLLEXPORT jl_value_t *jl_not_int(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_and_int(jl_value_t *a, jl_value_t *b);
Expand Down
22 changes: 0 additions & 22 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -834,33 +834,11 @@ fpiseq_n(double, 64)
#define fpiseq(a,b) \
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)

/*
 * fpislt_n(c_type, nbits) defines the static inline helper
 * fpislt<nbits>(a, b), returning 1 when `a` orders before `b` and 0 otherwise:
 *   - a non-NaN `a` orders before a NaN `b`;
 *   - any other combination involving a NaN yields 0;
 *   - otherwise the bit patterns are compared: when the signed view of `a` is
 *     non-negative, a signed `<` is used; when it is negative, the unsigned
 *     views are compared with `>`.
 * NOTE(review): `bits##nbits` is declared outside this view; presumably a
 * union exposing float (`f`), signed (`d`) and unsigned (`ud`) members of
 * width `nbits` -- confirm against its definition.
 */
#define fpislt_n(c_type, nbits) \
static inline int fpislt##nbits(c_type a, c_type b) JL_NOTSAFEPOINT \
{ \
bits##nbits ua, ub; \
ua.f = a; \
ub.f = b; \
if (!isnan(a) && isnan(b)) \
return 1; \
if (isnan(a) || isnan(b)) \
return 0; \
if (ua.d >= 0 && ua.d < ub.d) \
return 1; \
if (ua.d < 0 && ua.ud > ub.ud) \
return 1; \
return 0; \
}
fpislt_n(float, 32)
fpislt_n(double, 64)
#define fpislt(a, b) \
sizeof(a) == sizeof(float) ? fpislt32(a, b) : fpislt64(a, b)

bool_fintrinsic(eq,eq_float)
bool_fintrinsic(ne,ne_float)
bool_fintrinsic(lt,lt_float)
bool_fintrinsic(le,le_float)
bool_fintrinsic(fpiseq,fpiseq)
bool_fintrinsic(fpislt,fpislt)

// bitwise operators
#define and_op(a,b) a & b
Expand Down

0 comments on commit 79920db

Please sign in to comment.